Spaces:
Runtime error
Runtime error
| import os | |
def _fetch_lfs_weights():
    """Download Git LFS-tracked files (e.g. the model weights) into the working tree.

    Best-effort: failures are printed rather than raised, so startup
    continues and the model-file check further down reports the problem.
    """
    import subprocess

    for cmd in (["git", "lfs", "install"], ["git", "lfs", "pull"]):
        try:
            # Argument list (no shell) and an inspected return code, instead
            # of os.system whose failures were silently ignored.
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            print("Warning: 'git' executable not found; skipping Git LFS download.")
            return
        if result.returncode != 0:
            print(f"Warning: '{' '.join(cmd)}' exited with code {result.returncode}.")


_fetch_lfs_weights()
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| from safetensors.torch import load_file | |
| from PIL import Image | |
| from io import BytesIO | |
| import base64 | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
# ============================================================
# 1. Load Model
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "model_stunting.safetensors"
# A Git LFS pointer is a small text file starting with this prefix; it is
# what remains on disk when the real object was never downloaded.
_LFS_POINTER_PREFIX = b"version https://git-lfs"

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file '{MODEL_PATH}' tidak ditemukan di Space.")
# Fail with a clear message instead of an obscure safetensors parse error
# when `git lfs pull` did not fetch the weights.
with open(MODEL_PATH, "rb") as _model_fh:
    if _model_fh.read(len(_LFS_POINTER_PREFIX)) == _LFS_POINTER_PREFIX:
        raise RuntimeError(
            f"Model file '{MODEL_PATH}' masih berupa pointer Git LFS; "
            "jalankan 'git lfs pull'."
        )
# Tensors are loaded on CPU; the model is moved to `device` after loading.
state_dict = load_file(MODEL_PATH)
class Dense121(nn.Module):
    """DenseNet-121 whose final classifier is replaced by a ``num_classes``-way linear layer.

    Args:
        num_classes: Number of output logits.
        pretrained: If True, initialise the backbone from torchvision's
            default ImageNet weights (downloaded on first use); otherwise the
            backbone is randomly initialised.
    """

    def __init__(self, num_classes=2, pretrained=False):
        super().__init__()
        # `weights=` replaces torchvision's deprecated `pretrained=` argument.
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.dense121 = models.densenet121(weights=weights)
        in_features = self.dense121.classifier.in_features
        self.dense121.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the input batch ``x``."""
        return self.dense121(x)
# Build the network, restore the trained weights (strict key matching),
# then move it to the target device and switch it to eval mode.
model = Dense121(num_classes=2, pretrained=False)
model.load_state_dict(state_dict, strict=True)
model = model.to(device).eval()
# ============================================================
# 2. Image Transform
# ============================================================
# Resize to the fixed 224x224 network input, convert to a float tensor, then
# normalise each channel with the listed mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
# ============================================================
# 3. Prediction Helper
# ============================================================
def predict_pil(img: Image.Image):
    """Classify a single PIL image.

    Args:
        img: Input image. It is converted to RGB first, so grayscale,
            palette or RGBA images are accepted too.

    Returns:
        ``{"prediction": label, "confidence": p}``, where ``label`` is
        "Tidak Stunting" (class 0) or "Stunting" (class 1) and ``p`` is the
        softmax probability of that class.
    """
    # Normalize() uses 3-channel statistics, so the input must have exactly
    # 3 channels; convert() just returns a copy when the image is already RGB.
    tensor = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)[0]  # probabilities of the single sample
    pred = int(torch.argmax(probs).item())
    labels = ["Tidak Stunting", "Stunting"]
    return {
        "prediction": labels[pred],
        "confidence": float(probs[pred]),
    }
# ============================================================
# 4. FastAPI app (for the Flutter client)
# ============================================================
api = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted.
api.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# NOTE(review): the handler was previously not registered as a route; the
# "/predict" path is chosen here - confirm it matches the Flutter client.
@api.post("/predict")
async def predict_api(payload: dict):
    """Classify a base64-encoded image sent as a JSON body.

    Args:
        payload: JSON object with an ``"image"`` field containing a base64
            string. An optional data-URI header (e.g.
            ``data:image/jpeg;base64,``) is allowed; everything up to the
            last comma is discarded.

    Returns:
        The ``predict_pil`` result dict on success, or ``{"error": ...}``
        when the field is missing, is not a string, or does not decode to
        an image.
    """
    from fastapi.concurrency import run_in_threadpool

    if "image" not in payload:
        return {"error": "Field 'image' tidak ditemukan"}
    encoded = payload["image"]
    if not isinstance(encoded, str):
        return {"error": "Field 'image' harus berupa string base64"}
    # Strip an optional data-URI prefix such as "data:image/png;base64,".
    base64data = encoded.split(",")[-1]
    try:
        image_bytes = base64.b64decode(base64data)
        img = Image.open(BytesIO(image_bytes)).convert("RGB")
    except (ValueError, OSError):
        # binascii.Error (bad base64) subclasses ValueError; PIL's
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        return {"error": "Data 'image' bukan gambar base64 yang valid"}
    # Inference is compute-bound; run it in the threadpool so it does not
    # block the event loop while other requests are waiting.
    return await run_in_threadpool(predict_pil, img)
# ============================================================
# 5. GRADIO UI
# ============================================================
def gradio_predict(image):
    """Gradio callback that classifies the uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        The ``predict_pil`` result dict, or ``{"error": ...}`` when no image
        was provided.
    """
    if image is None:
        return {"error": "Gambar belum diunggah"}
    return predict_pil(image)
# Web UI: one PIL image upload in, the prediction dict rendered as JSON out.
gradio_ui = gr.Interface(
    title="Prediksi Stunting dari Foto",
    description="Upload foto anak untuk mendeteksi risiko stunting.",
    fn=gradio_predict,
    inputs=gr.Image(label="Upload Gambar Anak", type="pil"),
    outputs=gr.JSON(label="Hasil Prediksi"),
)
# ============================================================
# 6. MOUNT the Gradio UI onto the FastAPI app (this is the correct way)
# ============================================================
# NOTE(review): nothing in this file starts a server; `app` has to be served
# externally (an ASGI server pointed at `app`). Confirm how the Space
# launches this module.
app = gr.mount_gradio_app(
    api,
    gradio_ui,
    path="/",
)