# Author: nilotpaldhar2004
# Last commit: 90c6890 "Update main.py" (verified)
import os
import io
import logging
import time
from contextlib import asynccontextmanager
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
import uvicorn
# ── 1. CONFIGURATION ─────────────────────────────────────────────────────────
logging.basicConfig(format="%(asctime)s | %(message)s", level=logging.INFO)
logger = logging.getLogger("dermsight-api")

# Locations of the trained weights and the class-name array, plus the HTTP
# port; each can be overridden through an environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "best_resnet50_skin.pth")
CLASSES_PATH = os.environ.get("CLASSES_PATH", "classes.npy")
PORT = int(os.environ.get("PORT", "7860"))

# Process-wide registry of loaded resources ("classes", "model"); populated
# at startup by the lifespan handler and cleared on shutdown.
ml = dict()

# Preprocessing: fixed 448x448 resize, conversion to a float tensor, then
# per-channel normalization with the standard ImageNet mean / std.
# NOTE(review): the 448x448 input size must match the training pipeline.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# ── 2. MODEL ARCHITECTURE ───────────────────────────────────────────────────
def build_model(num_classes: int = 7) -> nn.Module:
    """Build the ResNet-50 skin classifier and load its trained weights.

    The backbone is torchvision's ResNet-50 created without pretrained
    weights; its ``fc`` layer is replaced by an MLP head that must have the
    same layer shapes as the head used in training, otherwise
    ``load_state_dict`` rejects the checkpoint.

    Args:
        num_classes: Number of logits produced by the final layer. Defaults
            to 7, the value this service was originally written for; it must
            equal the output size stored in the checkpoint.

    Returns:
        The model with weights loaded onto the CPU, switched to eval mode.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
        RuntimeError: (from ``load_state_dict``) if the checkpoint's keys or
            tensor shapes do not match this architecture.
    """
    model = models.resnet50(weights=None)
    in_features = model.fc.in_features

    # Linear/BatchNorm sizes must mirror the training head exactly because they
    # carry parameters stored in the checkpoint. Dropout layers have no
    # parameters and are disabled by eval(), so their rates (0.4 / 0.3 / 0.2)
    # do not affect loading or inference; they are kept only for parity with
    # the training definition.
    model.fc = nn.Sequential(
        nn.Linear(in_features, 2048),
        nn.BatchNorm1d(2048),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Missing {MODEL_PATH}")
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # eval() is required: BatchNorm1d in training mode rejects a batch of one
    # and would update its running statistics.
    model.eval()
    return model
# ── 3. LIFESPAN ─────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the class names and model at startup; drop them at shutdown.

    Loading is best-effort. If anything fails, the error is logged with its
    full traceback and the app still starts, so ``/predict`` keeps answering
    503 until both resources are present in ``ml``.
    """
    logger.info("Loading AI Resources...")
    try:
        # allow_pickle=True lets NumPy unpickle object arrays (presumably an
        # array of class-name strings — confirm). Unpickling can execute code,
        # so CLASSES_PATH must point at a trusted file.
        ml["classes"] = np.load(CLASSES_PATH, allow_pickle=True)
        ml["model"] = build_model()
        logger.info("Resources loaded successfully.")
    except Exception as e:
        # Deliberately non-fatal, but keep the traceback for diagnosis.
        logger.exception("Startup failed: %s", e)
    try:
        yield
    finally:
        # Release resources even if shutdown is triggered by an exception.
        ml.clear()
# ── 4. API ──────────────────────────────────────────────────────────────────
# The application object; `lifespan` loads and releases the ML resources.
app = FastAPI(lifespan=lifespan, title="DermSight PRO")

# Wide-open CORS: browsers may call the API from any origin, with any method
# and any header. allow_credentials is left at its default.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
@app.get("/", include_in_schema=False)
async def serve_frontend():
    """Serve the bundled ``index.html`` page, or a JSON status note without it.

    The route is hidden from the OpenAPI schema. The page is looked up
    relative to the process's current working directory.
    """
    index_page = "index.html"
    if not os.path.exists(index_page):
        return {"message": "DermSight PRO API is live."}
    return FileResponse(index_page)
class _InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""


def _classify_image_bytes(model, classes, img_bytes: bytes) -> dict:
    """Decode an uploaded image, run the classifier, and build the response body.

    This is synchronous, CPU-bound work; ``predict`` runs it in a worker
    thread so the event loop stays free to serve other requests.

    Args:
        model: The loaded classifier (expected to be in eval mode).
        classes: Class names, one per model output, in output order.
        img_bytes: Raw bytes of the uploaded file.

    Returns:
        Dict with ``prediction``, ``confidence`` and ``all_probabilities``.

    Raises:
        _InvalidImageError: If ``img_bytes`` is not a decodable image.
        RuntimeError: If ``len(classes)`` differs from the number of outputs.
    """
    try:
        # The context manager closes the source image once converted.
        with Image.open(io.BytesIO(img_bytes)) as img:
            image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise _InvalidImageError(str(exc)) from exc

    tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(tensor)
    probs = torch.nn.functional.softmax(outputs[0], dim=0)

    # A classes file that disagrees with the model head would otherwise raise
    # IndexError (too many names) or silently drop classes (too few).
    if len(classes) != probs.numel():
        raise RuntimeError(
            f"{len(classes)} class names but the model produced "
            f"{probs.numel()} outputs"
        )

    conf, idx = torch.max(probs, 0)
    # Lower-case keys so the JavaScript frontend's lookups (e.g. 'mel') match
    # regardless of how names are cased in the classes file.
    all_probabilities = {
        str(name).lower(): round(p * 100, 2)
        for name, p in zip(classes, probs.tolist())
    }
    return {
        "prediction": str(classes[idx.item()]),  # original casing for display
        "confidence": f"{conf.item() * 100:.2f}%",
        "all_probabilities": all_probabilities,
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    Returns the top class, its confidence as a percentage string, per-class
    probabilities in percent (lower-cased keys) and the server-side latency
    in milliseconds.

    Raises:
        HTTPException: 503 if the model or class names are not loaded;
            400 if the upload is not a decodable image; 500 for any other
            inference failure (details are logged, not sent to the client).
    """
    from fastapi.concurrency import run_in_threadpool

    model = ml.get("model")
    classes = ml.get("classes")
    # Explicit None checks: container modules such as nn.Sequential define
    # __len__, so truthiness is not a reliable "is loaded" test.
    if model is None or classes is None:
        raise HTTPException(status_code=503, detail="Model not ready.")

    t0 = time.perf_counter()
    try:
        img_bytes = await file.read()
        # Decode + forward pass are CPU-bound; keep them off the event loop.
        result = await run_in_threadpool(
            _classify_image_bytes, model, classes, img_bytes
        )
    except _InvalidImageError as e:
        logger.warning("Rejected upload %r: %s", file.filename, e)
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from e
    except Exception as e:
        # Log the full traceback but do not leak internal details to clients.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Prediction failed.") from e

    result["latency_ms"] = round((time.perf_counter() - t0) * 1000, 2)
    return result
if __name__ == "__main__":
    # Pass the app object instead of the "main:app" import string. When this
    # file runs as a script it is the __main__ module, and an import string
    # would make uvicorn import it a second time as "main", re-running all
    # module-level setup. An import string is only needed for reload=True or
    # workers > 1, neither of which is used here.
    uvicorn.run(app, host="0.0.0.0", port=PORT)