"""FastAPI service exposing mold-detection inference endpoints.

Serves two multi-task CNNs (a baseline ResNet-50 and a ConvNeXt main model),
plus Grad-CAM explanations and an advanced decision pipeline backed by DINO
reference embeddings (lazily initialized on first /predict/v2 call).
"""

from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch, io
from torchvision import transforms
from model import MultiTaskResNet50, MultiTaskConvNeXt, find_last_conv2d
from decision import final_decision
from advanced_decision import (
    mc_uncertainty,
    patch_consistency,
    final_decision_v2,
)
from gradcam import GradCAM
from typing import Optional

app = FastAPI(title="Mold Detection API (ResNet + ConvNeXt)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------
# Load baseline model (ResNet)
# ------------------
resnet_ckpt = torch.load(
    "resnet50_multitask_mold.pth", map_location=device
)

# Handle different checkpoint formats: either a dict carrying "model" and
# "classes" keys, or a bare state_dict with no class metadata.
if isinstance(resnet_ckpt, dict) and "model" in resnet_ckpt:
    resnet_state = resnet_ckpt["model"]
    resnet_classes = resnet_ckpt.get("classes", [])
else:
    resnet_state = resnet_ckpt
    resnet_classes = []

# Fallbacks (9 classes, "mold" at index 4) when the checkpoint carries no
# class metadata — presumably the training defaults; TODO confirm against
# the training script.
resnet_num_classes = len(resnet_classes) if resnet_classes else 9
resnet_mold_idx = (
    resnet_classes.index("mold")
    if resnet_classes and "mold" in resnet_classes
    else 4
)

resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
resnet_model.load_state_dict(resnet_state)
resnet_model.eval()

# ------------------
# Load main model (ConvNeXt)
# ------------------
# Expecting checkpoint with keys:
#  - "model": state_dict
#  - "classes": list of class names (length N, mold at some index)
ckpt = torch.load("best_convnext_multitask.pth", map_location=device)
classes = ckpt.get("classes") or []
num_classes = len(classes) if classes else 9
# BUGFIX: guard against a class list that lacks "mold" (same fallback as the
# ResNet path above); the previous `classes.index("mold")` raised ValueError
# at import time for such checkpoints.
mold_idx = classes.index("mold") if classes and "mold" in classes else 4

model = MultiTaskConvNeXt(num_classes).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

# ------------------
# Transforms (224x224 resize + ImageNet normalization)
# ------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
])

# ------------------
# Grad-CAM target layer (computed, not stored in model state_dict)
# ------------------
target_layer = find_last_conv2d(model.backbone)
gradcam = GradCAM(model, target_layer) if target_layer is not None else None

# ------------------
# DINO (lazy loaded)
# ------------------
dino: Optional[object] = None
mold_embs = None


def ensure_dino():
    """Lazily initialize the DINO model and mold reference embeddings.

    Raises:
        HTTPException: 503 when the optional dependencies are missing or
            initialization fails, so /predict/v2 degrades gracefully.
    """
    global dino, mold_embs
    if dino is None:
        try:
            from dino import load_dino, build_embeddings
        except ModuleNotFoundError as e:
            # Local/dev env might not have optional deps like `datasets`.
            raise HTTPException(
                status_code=503,
                detail=(
                    "DINO dependencies are not installed. "
                    "Install extras with: pip install datasets scikit-learn"
                ),
            ) from e
        try:
            dino = load_dino(device)
            mold_embs = build_embeddings(dino, transform, device)
        except Exception as e:
            raise HTTPException(
                status_code=503,
                detail=f"Failed to initialize DINO reference embeddings: {e}",
            ) from e


async def _read_image(file: UploadFile) -> Image.Image:
    """Read an uploaded file into an RGB PIL image.

    Raises:
        HTTPException: 400 when the upload is not a decodable image
            (previously this surfaced as an unhandled 500).
    """
    data = await file.read()
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {e}"
        ) from e


# ------------------
# API endpoints
# ------------------
@app.post("/predict/v1")
async def predict_v1(file: UploadFile):
    """ConvNeXt prediction with the simple threshold decision."""
    img = await _read_image(file)
    img_t = transform(img).to(device)
    with torch.no_grad():
        out = model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]
    mold_p = cp[mold_idx].item()
    bio_p = bp[1].item()  # index 1 = "biological" class of the binary head
    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }


@app.post("/predict/v2")
async def predict_v2(file: UploadFile):
    """Advanced prediction with uncertainty, patch and DINO similarity checks."""
    ensure_dino()
    # Import similarity lazily (only needed for v2)
    from dino import similarity

    img = await _read_image(file)
    img_t = transform(img).to(device)
    with torch.no_grad():
        out = model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]
    mold_p = cp[mold_idx].item()
    bio_p = bp[1].item()

    # Confidence checks: MC-dropout spread, patch-level agreement, and
    # embedding similarity to known mold references.
    mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
    patch_ratio = patch_consistency(
        model, img, transform, mold_idx, device
    )
    dino_sim = similarity(
        dino, mold_embs, img, transform, device
    )

    decision = final_decision_v2(
        mold_p, bio_p, std_p, patch_ratio, dino_sim
    )
    return {
        "decision": decision,
        "model_outputs": {
            "mold_probability": round(mold_p, 3),
            "biological_probability": round(bio_p, 3),
        },
        "confidence_checks": {
            "uncertainty": round(std_p, 3),
            "patch_ratio": round(patch_ratio, 3),
            "dino_similarity": round(dino_sim, 3),
        },
    }


@app.post("/explain/gradcam")
async def explain_gradcam(file: UploadFile):
    """Return a Grad-CAM heatmap for the mold class as a nested list."""
    # BUGFIX: `gradcam` is None when no Conv2d layer was found in the
    # backbone; the previous code crashed here with AttributeError (500).
    if gradcam is None:
        raise HTTPException(
            status_code=503,
            detail="Grad-CAM is unavailable: no Conv2d target layer found.",
        )
    img = await _read_image(file)
    img_t = transform(img).to(device)
    cam = gradcam.generate(img_t, mold_idx)
    return {"gradcam": cam.tolist()}


@app.post("/predict/resnet")
async def predict_resnet(file: UploadFile):
    """Baseline ResNet-50 prediction with the simple threshold decision."""
    img = await _read_image(file)
    img_t = transform(img).to(device)
    with torch.no_grad():
        out = resnet_model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]
    mold_p = cp[resnet_mold_idx].item()
    bio_p = bp[1].item()
    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }