# Stunting_app / app.py
# Author: Syahhh01 — last commit "Update app.py" (38f86e6, verified)
import os
# Fetch Git LFS-tracked files (such as the model weights) before they are
# loaded below. Best effort: a non-zero status is surfaced as a warning instead
# of being silently ignored, so a missing or pointer-only weights file is easier
# to diagnose if loading fails later. It is deliberately not fatal.
import warnings

for _lfs_cmd in ("git lfs install", "git lfs pull"):
    _lfs_status = os.system(_lfs_cmd)
    if _lfs_status != 0:
        warnings.warn(
            f"'{_lfs_cmd}' exited with status {_lfs_status}; "
            "model weights may be missing or still Git LFS pointer files."
        )
import torch
import torch.nn as nn
from torchvision import transforms, models
from safetensors.torch import load_file
from PIL import Image
from io import BytesIO
import base64
import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# ============================================================
# 1. Load Model
# ============================================================
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Trained weights, resolved relative to the process's working directory
# (presumably fetched by the `git lfs pull` at the top of the file).
MODEL_PATH = "model_stunting.safetensors"
if not os.path.exists(MODEL_PATH):
    # Fail fast with a readable message instead of a loader error.
    raise FileNotFoundError(f"Model file '{MODEL_PATH}' tidak ditemukan di Space.")
state_dict = load_file(MODEL_PATH)
class Dense121(nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by a linear head.

    The backbone is stored as ``self.dense121``, so every state-dict key is
    prefixed with ``dense121.``. Keep this attribute name in sync with any
    saved weights loaded via ``load_state_dict``.

    Args:
        num_classes: Number of output logits produced by the new head.
        pretrained: If True, initialize the backbone from torchvision's
            ImageNet weights (downloaded on first use). If False, start from
            random initialization. Callers that immediately load their own
            state dict should pass False.
    """

    def __init__(self, num_classes=2, pretrained=False):
        super().__init__()
        # Honor the caller's flag rather than hard-coding False.
        self.dense121 = models.densenet121(pretrained=pretrained)
        in_features = self.dense121.classifier.in_features
        # Swap the stock classifier for one with `num_classes` outputs.
        self.dense121.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the batch ``x``."""
        return self.dense121(x)
# Build the network, restore the trained weights, then move it to `device`
# and switch it to evaluation mode (BatchNorm uses running statistics).
# `nn.Module.to` and `nn.Module.eval` both return the module itself, so the
# chained call rebinds `model` to the same object.
model = Dense121(num_classes=2, pretrained=False)
model.load_state_dict(state_dict)
model = model.to(device).eval()
# ============================================================
# 2. Image Transform
# ============================================================
# Preprocessing applied to every image before inference: resize to a fixed
# 224x224 (aspect ratio is not preserved), convert to a float tensor, then
# normalize each channel with the standard ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# ============================================================
# 3. Prediction Helper
# ============================================================
def predict_pil(img: Image.Image):
    """Classify one PIL image as "Tidak Stunting" or "Stunting".

    Args:
        img: Input image in any PIL mode. It is converted to RGB first, so
            grayscale, palette and RGBA images are accepted as well as RGB.

    Returns:
        dict with ``"prediction"`` (the winning label) and ``"confidence"``
        (that label's softmax probability, as a Python float).
    """
    # The 3-channel Normalize in `transform` and the DenseNet input layer both
    # need exactly three channels. Converting here means this helper does not
    # rely on every caller having done so.
    img = img.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)  # add a batch axis of 1
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)
    # Take the argmax explicitly over the class axis of the single batch item,
    # rather than over the flattened tensor.
    pred = int(torch.argmax(probs, dim=1)[0])
    labels = ["Tidak Stunting", "Stunting"]
    return {
        "prediction": labels[pred],
        "confidence": float(probs[0][pred]),
    }
# ============================================================
# 4. FASTAPI for Flutter
# ============================================================
# JSON API for the Flutter client; the Gradio UI is mounted onto this same
# app at the end of the file.
api = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
api.add_middleware(CORSMiddleware, **_cors_options)
@api.post("/run/predict")
async def predict_api(payload: dict):
    """Predict stunting risk from a base64-encoded image sent as JSON.

    Expects ``{"image": "<base64>"}``. A data-URL prefix such as
    ``data:image/jpeg;base64,`` is tolerated, because only the text after the
    last comma is decoded.

    Returns the ``predict_pil`` result dict on success. If the field is
    missing, is not a string, or does not decode to an image, returns
    ``{"error": <message>}`` in a normal (non-exception) response.
    """
    import asyncio

    if "image" not in payload:
        return {"error": "Field 'image' tidak ditemukan"}
    raw_image = payload["image"]
    if not isinstance(raw_image, str):
        return {"error": "Field 'image' harus berupa string base64"}

    try:
        # decode base64 (drop an optional "data:...;base64," prefix)
        image_bytes = base64.b64decode(raw_image.split(",")[-1])
        with Image.open(BytesIO(image_bytes)) as src:
            img = src.convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError):
        # Bad base64 raises binascii.Error (a ValueError). Unrecognized or
        # truncated image data raises OSError (incl. UnidentifiedImageError).
        return {"error": "Data gambar tidak valid"}

    # Inference is synchronous and compute-bound. Running it in a worker thread
    # keeps this async endpoint from blocking the event loop, and with it every
    # other request served by this app.
    return await asyncio.to_thread(predict_pil, img)
# ============================================================
# 5. GRADIO UI
# ============================================================
def gradio_predict(image):
    """Gradio callback: classify the uploaded image with ``predict_pil``.

    Gradio passes ``None`` when the user submits without uploading an image.
    In that case an ``{"error": ...}`` dict is returned (shown in the JSON
    output), instead of crashing inside the image transform.
    """
    if image is None:
        return {"error": "Gambar belum diunggah"}
    return predict_pil(image)
# One input, one output: a PIL image upload goes in, and the prediction dict is
# rendered as JSON.
_image_input = gr.Image(type="pil", label="Upload Gambar Anak")
_result_output = gr.JSON(label="Hasil Prediksi")

gradio_ui = gr.Interface(
    fn=gradio_predict,
    inputs=_image_input,
    outputs=_result_output,
    title="Prediksi Stunting dari Foto",
    description="Upload foto anak untuk mendeteksi risiko stunting.",
)
# ============================================================
# 6. MOUNT (this is the correct way)
# ============================================================
# Serve the Gradio UI at "/" on top of the FastAPI app. The /run/predict route
# was registered on `api` before this mount; Starlette matches routes in
# registration order, so that route should take precedence. `app` is
# presumably the ASGI object the Space's server runs.
app = gr.mount_gradio_app(app=api, blocks=gradio_ui, path="/")