# Author: Syahhh01
# Commit ad4fed5 (verified): "Update app.py"
import io

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from safetensors.torch import load_file
from torchvision import models, transforms
# FastAPI application instance; every route in this module registers on it.
app = FastAPI()

# =============== ROOT ROUTE ==================
_ROOT_PAYLOAD_MESSAGE = "Stunting Detector API is running!"


@app.get("/")
async def root():
    """Return a fixed message confirming that the API is up."""
    return dict(message=_ROOT_PAYLOAD_MESSAGE)
# =============== LOAD MODEL SAFETENSORS ==================
class Dense121(nn.Module):
    """DenseNet-121 whose classifier is replaced by a ``num_classes``-way linear head.

    The torchvision model is stored under the ``dense121`` attribute, so
    state-dict keys are prefixed with ``dense121.``.

    Args:
        num_classes: Number of output logits produced by the new head.
        pretrained: If True, start the backbone from torchvision's ImageNet
            weights; otherwise initialise it randomly.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.dense121 = self._build_backbone(pretrained)
        # Replace torchvision's default classifier with a task-specific head.
        in_features = self.dense121.classifier.in_features
        self.dense121.classifier = nn.Linear(in_features, num_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Return torchvision's DenseNet-121 using whichever weights API exists."""
        try:
            weights_enum = models.DenseNet121_Weights
        except AttributeError:
            # torchvision < 0.13 has no weights enums; only `pretrained` works.
            return models.densenet121(pretrained=pretrained)
        # Modern torchvision: `weights=None` is the non-deprecated "no weights".
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.densenet121(weights=weights)

    def forward(self, x):
        """Return raw (unnormalised) class logits for input batch ``x``."""
        return self.dense121(x)
# Build the architecture WITHOUT fetching ImageNet weights: load_state_dict is
# strict by default, so a successful load below overwrites every parameter and
# buffer anyway, and downloading them first only adds a network dependency
# (and startup latency) to service boot.
model = Dense121(num_classes=2, pretrained=False)
state_dict = load_file("model_stunting.safetensors")
model.load_state_dict(state_dict)
model.eval()  # inference behaviour for layers such as BatchNorm
# =============== IMAGE PREPROCESS ==================
# Steps applied, in order, to every decoded RGB image before inference.
# NOTE(review): there is no transforms.Normalize step. If the checkpoint was
# trained on ImageNet-normalised inputs, the matching Normalize belongs here —
# confirm against the training pipeline before changing.
_preprocess_steps = [
    transforms.Resize((224, 224)),  # exact 224x224; aspect ratio is not kept
    transforms.ToTensor(),  # PIL image -> float CHW tensor scaled to [0, 1]
]
preprocess = transforms.Compose(_preprocess_steps)
# =============== API ENDPOINT ==================
# Class names indexed by model output position (the model is built with
# num_classes=2).
# NOTE(review): order assumed to match the label encoding used in training —
# confirm before changing.
LABELS = ["normal", "stunting"]


def _predict_from_bytes(img_bytes):
    """Decode an uploaded image and classify it with the loaded model.

    This is blocking, CPU-bound work, so the async endpoint runs it in a
    worker thread instead of on the event loop.

    Args:
        img_bytes: Raw bytes of the uploaded file.

    Returns:
        dict with ``prediction`` (softmax probability per class, ordered like
        ``LABELS``), ``label`` (name of the most likely class) and
        ``confidence`` (that class's probability).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated files also fail here.
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unrecognised or empty data) subclasses
        # OSError; report a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    tensor = preprocess(img).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        output = model(tensor)
    probs = torch.softmax(output, dim=1)[0].cpu().numpy().tolist()
    pred_idx = int(np.argmax(probs))
    return {
        "prediction": probs,
        "label": LABELS[pred_idx],
        "confidence": probs[pred_idx],
    }


@app.post("/predict-image")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as ``normal`` or ``stunting``.

    Responds with ``prediction`` (per-class probabilities), ``label`` and
    ``confidence``; responds with HTTP 400 when the upload is not a
    decodable image.
    """
    img_bytes = await file.read()
    # Offload inference so a slow forward pass does not stall other requests.
    return await run_in_threadpool(_predict_from_bytes, img_bytes)