# AdarshRajDS
# Add ResNet baseline and ConvNeXt v2 backend Dockerfile 1
# c46050c
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch, io
from torchvision import transforms
from model import MultiTaskResNet50, MultiTaskConvNeXt, find_last_conv2d
from decision import final_decision
from advanced_decision import (
mc_uncertainty,
patch_consistency,
final_decision_v2
)
from gradcam import GradCAM
from typing import Optional
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

app = FastAPI(title="Mold Detection API (ResNet + ConvNeXt)")

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ------------------
# Load baseline model (ResNet)
# ------------------
resnet_ckpt = torch.load("resnet50_multitask_mold.pth", map_location=device)

# Two checkpoint layouts are accepted:
#   * a dict with a "model" entry (state_dict) and an optional "classes" list
#   * anything else, which is used directly as the state_dict
resnet_state, resnet_classes = resnet_ckpt, []
if isinstance(resnet_ckpt, dict) and "model" in resnet_ckpt:
    resnet_state = resnet_ckpt["model"]
    resnet_classes = resnet_ckpt.get("classes", [])

# Without class names, fall back to 9 classes.
if resnet_classes:
    resnet_num_classes = len(resnet_classes)
else:
    resnet_num_classes = 9

# Index of the "mold" class; defaults to 4 when the names do not provide it.
resnet_mold_idx = 4
if resnet_classes and "mold" in resnet_classes:
    resnet_mold_idx = resnet_classes.index("mold")

resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
resnet_model.load_state_dict(resnet_state)
resnet_model.eval()  # inference mode
# ------------------
# Load main model (ConvNeXt)
# ------------------
# Preferred checkpoint layout is a dict with keys:
# - "model": state_dict
# - "classes": list of class names (length N, mold at some index)
# As with the ResNet baseline, a checkpoint without a "model" entry is used
# directly as the state_dict, and missing class names fall back to
# 9 classes with "mold" at index 4.
ckpt = torch.load("best_convnext_multitask.pth", map_location=device)

if isinstance(ckpt, dict) and "model" in ckpt:
    convnext_state = ckpt["model"]
    classes = ckpt.get("classes") or []
else:
    convnext_state = ckpt
    classes = []

num_classes = len(classes) if classes else 9
# Guard the lookup: a class list without "mold" must not crash at import time.
mold_idx = classes.index("mold") if "mold" in classes else 4

model = MultiTaskConvNeXt(num_classes).to(device)
model.load_state_dict(convnext_state)
model.eval()
# ------------------
# Transforms
# ------------------
# Shared preprocessing for every endpoint: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std given below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ------------------
# Grad-CAM target layer (computed, not stored in model state_dict)
# ------------------
target_layer = find_last_conv2d(model.backbone)

# When no target layer is found, Grad-CAM is disabled (gradcam stays None).
if target_layer is None:
    gradcam = None
else:
    gradcam = GradCAM(model, target_layer)
# ------------------
# DINO (lazy loaded)
# ------------------
# Both globals stay None until ensure_dino() has completed successfully.
dino: Optional[object] = None
mold_embs = None


def ensure_dino():
    """Load the DINO model and its reference embeddings on first use.

    Does nothing once ``dino`` has been set. The two globals are assigned
    together, and only after both ``load_dino`` and ``build_embeddings`` have
    succeeded; a failure in either step therefore leaves both at ``None`` so
    the next call retries instead of running with ``mold_embs`` unset.

    Raises:
        HTTPException: 503 when the ``dino`` module (or one of its imports)
            is missing, or when loading the model / building the embeddings
            fails.
    """
    global dino, mold_embs
    if dino is not None:
        return

    try:
        from dino import load_dino, build_embeddings
    except ModuleNotFoundError as e:
        # Local/dev env might not have optional deps like `datasets`.
        raise HTTPException(
            status_code=503,
            detail=(
                "DINO dependencies are not installed. "
                "Install extras with: pip install datasets scikit-learn"
            ),
        ) from e

    try:
        # Work on locals first so a failure cannot leave half-initialised state.
        loaded_dino = load_dino(device)
        loaded_embs = build_embeddings(loaded_dino, transform, device)
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Failed to initialize DINO reference embeddings: {e}",
        ) from e

    dino, mold_embs = loaded_dino, loaded_embs
# ------------------
# API endpoints
# ------------------
@app.post("/predict/v1")
async def predict_v1(file: UploadFile):
    """Classify an uploaded image with the ConvNeXt model.

    The upload is decoded to RGB, passed through ``transform`` and run through
    ``model`` as a batch of one with gradients disabled. The softmaxed
    ``"class"`` and ``"bio"`` heads are then combined by ``final_decision``.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``decision``, ``mold_probability`` and
        ``biological_probability`` (probabilities rounded to 3 decimals).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # Pillow signals unreadable or truncated images with OSError
        # subclasses; report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = model(img_t.unsqueeze(0))

    cp = torch.softmax(out["class"], 1)[0]
    bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    # Index 1 of the "bio" head is reported as the biological probability.
    bio_p = bp[1].item()

    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }
@app.post("/predict/v2")
async def predict_v2(file: UploadFile):
    """Classify an uploaded image with ConvNeXt plus extra confidence checks.

    On top of the softmaxed ``"class"`` / ``"bio"`` outputs used by v1, this
    endpoint calls ``mc_uncertainty``, ``patch_consistency`` and the DINO
    ``similarity`` helper, then feeds everything to ``final_decision_v2``.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``decision``, ``model_outputs`` (mold and biological
        probabilities) and ``confidence_checks`` (uncertainty, patch ratio,
        DINO similarity), all numeric values rounded to 3 decimals.

    Raises:
        HTTPException: 503 when DINO cannot be initialised (raised by
            ``ensure_dino``); 400 when the upload cannot be decoded as an
            image.
    """
    ensure_dino()
    # Import similarity lazily (only needed for v2)
    from dino import similarity

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # Pillow signals unreadable or truncated images with OSError
        # subclasses; report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = model(img_t.unsqueeze(0))

    cp = torch.softmax(out["class"], 1)[0]
    bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    # Index 1 of the "bio" head is reported as the biological probability.
    bio_p = bp[1].item()

    # Only the second value (std_p) is used below; mean_p is kept for clarity.
    mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
    patch_ratio = patch_consistency(
        model, img, transform, mold_idx, device
    )
    dino_sim = similarity(
        dino, mold_embs, img, transform, device
    )

    decision = final_decision_v2(
        mold_p, bio_p, std_p, patch_ratio, dino_sim
    )
    return {
        "decision": decision,
        "model_outputs": {
            "mold_probability": round(mold_p, 3),
            "biological_probability": round(bio_p, 3),
        },
        "confidence_checks": {
            "uncertainty": round(std_p, 3),
            "patch_ratio": round(patch_ratio, 3),
            "dino_similarity": round(dino_sim, 3),
        },
    }
@app.post("/explain/gradcam")
async def explain_gradcam(file: UploadFile):
    """Return the Grad-CAM map of the ConvNeXt model for the mold class.

    Args:
        file: The uploaded image.

    Returns:
        ``{"gradcam": ...}`` where the value is ``cam.tolist()`` of whatever
        ``GradCAM.generate`` returns (presumably a nested list of floats —
        confirm against the ``gradcam`` module).

    Raises:
        HTTPException: 503 when Grad-CAM is disabled because no target layer
            was found at startup; 400 when the upload cannot be decoded as an
            image.
    """
    # `gradcam` is None when find_last_conv2d() found no layer; without this
    # guard the call below would fail with AttributeError (HTTP 500).
    if gradcam is None:
        raise HTTPException(
            status_code=503,
            detail="Grad-CAM is unavailable: no target layer was found in the model.",
        )

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # Pillow signals unreadable or truncated images with OSError
        # subclasses; report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from e

    img_t = transform(img).to(device)
    cam = gradcam.generate(img_t, mold_idx)
    return {"gradcam": cam.tolist()}
@app.post("/predict/resnet")
async def predict_resnet(file: UploadFile):
    """Classify an uploaded image with the baseline ResNet model.

    Same pipeline as ``/predict/v1`` but using ``resnet_model`` and
    ``resnet_mold_idx``.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``decision``, ``mold_probability`` and
        ``biological_probability`` (probabilities rounded to 3 decimals).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # Pillow signals unreadable or truncated images with OSError
        # subclasses; report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = resnet_model(img_t.unsqueeze(0))

    cp = torch.softmax(out["class"], 1)[0]
    bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[resnet_mold_idx].item()
    # Index 1 of the "bio" head is reported as the biological probability.
    bio_p = bp[1].item()

    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }