Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| import torch, io | |
| from torchvision import transforms | |
| from model import MultiTaskResNet50, MultiTaskConvNeXt, find_last_conv2d | |
| from decision import final_decision | |
| from advanced_decision import ( | |
| mc_uncertainty, | |
| patch_consistency, | |
| final_decision_v2 | |
| ) | |
| from gradcam import GradCAM | |
| from typing import Optional | |
# FastAPI application object shared by every endpoint in this module.
app = FastAPI(title="Mold Detection API (ResNet + ConvNeXt)")

# CORS is left wide open: any origin, method and header is accepted.
# NOTE(review): confirm the wildcard policy is intended outside of demos.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ------------------
# Load baseline model (ResNet)
# ------------------
resnet_ckpt = torch.load("resnet50_multitask_mold.pth", map_location=device)

# Two checkpoint layouts are accepted:
#   * a wrapper dict holding the weights under "model" (and optionally the
#     class names under "classes"), or
#   * anything else, which is treated as the state_dict itself.
resnet_state, resnet_classes = resnet_ckpt, []
if isinstance(resnet_ckpt, dict) and "model" in resnet_ckpt:
    resnet_state = resnet_ckpt["model"]
    resnet_classes = resnet_ckpt.get("classes", [])

# Without class names, fall back to 9 output classes.
if resnet_classes:
    resnet_num_classes = len(resnet_classes)
else:
    resnet_num_classes = 9

# Without a "mold" entry in the class names, fall back to index 4.
if resnet_classes and "mold" in resnet_classes:
    resnet_mold_idx = resnet_classes.index("mold")
else:
    resnet_mold_idx = 4

resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
resnet_model.load_state_dict(resnet_state)
resnet_model.eval()  # inference mode for serving
# ------------------
# Load main model (ConvNeXt)
# ------------------
# Preferred checkpoint layout:
#   - "model": state_dict
#   - "classes": list of class names (length N, mold at some index)
# A bare state_dict (no wrapper dict) is accepted as well, mirroring the
# ResNet loader in this file.
ckpt = torch.load("best_convnext_multitask.pth", map_location=device)
if isinstance(ckpt, dict) and "model" in ckpt:
    _convnext_state = ckpt["model"]
    # `or []` also covers an explicit "classes": None entry.
    classes = ckpt.get("classes") or []
else:
    _convnext_state = ckpt
    classes = []

# Without class names, fall back to 9 output classes.
num_classes = len(classes) if classes else 9
# Fall back to index 4 when "mold" is not among the class names, instead of
# letting list.index() raise ValueError while the module is being imported.
mold_idx = classes.index("mold") if "mold" in classes else 4

model = MultiTaskConvNeXt(num_classes).to(device)
model.load_state_dict(_convnext_state)
model.eval()  # inference mode for serving
# ------------------
# Transforms
# ------------------
# Preprocessing applied to every PIL image before it reaches a model:
# resize to 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values match the usual ImageNet statistics --
# confirm they are the ones used at training time.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ------------------
# Grad-CAM target layer (computed, not stored in model state_dict)
# ------------------
target_layer = find_last_conv2d(model.backbone)

# `gradcam` stays None when no target layer was found in the backbone.
if target_layer is None:
    gradcam = None
else:
    gradcam = GradCAM(model, target_layer)
# ------------------
# DINO (lazy loaded)
# ------------------
# Both globals stay None until ensure_dino() succeeds once.
dino: Optional[object] = None
mold_embs = None


def ensure_dino():
    """Load the DINO model and its reference embeddings on first use.

    Subsequent calls return immediately once initialisation has succeeded.
    The two globals are only assigned after *both* the model and the
    embeddings were built, so a failure half-way through leaves the module
    uninitialised and the next call retries from scratch.

    Raises:
        HTTPException: 503 if the ``dino`` module (or one of its
            dependencies) cannot be imported, or if loading the model or
            building the reference embeddings fails.
    """
    global dino, mold_embs
    if dino is not None:
        return

    try:
        from dino import load_dino, build_embeddings
    except ModuleNotFoundError as e:
        # Local/dev env might not have optional deps like `datasets`.
        raise HTTPException(
            status_code=503,
            detail=(
                "DINO dependencies are not installed. "
                "Install extras with: pip install datasets scikit-learn"
            ),
        ) from e

    try:
        loaded_dino = load_dino(device)
        loaded_embs = build_embeddings(loaded_dino, transform, device)
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Failed to initialize DINO reference embeddings: {e}",
        ) from e

    # Publish both values together: `dino` must never be set while
    # `mold_embs` is still None, or later calls would skip initialisation.
    dino, mold_embs = loaded_dino, loaded_embs
# ------------------
# API endpoints
# ------------------
# NOTE(review): no route decorators are visible on the handlers in this view;
# confirm they are registered on `app`.
async def predict_v1(file: UploadFile):
    """Classify an uploaded image with the ConvNeXt model.

    Args:
        file: Uploaded image; its bytes are decoded with PIL and converted
            to RGB before preprocessing.

    Returns:
        A dict with "decision" (the result of ``final_decision``) plus
        "mold_probability" and "biological_probability", both rounded to
        three decimals.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass, so this also
        # covers empty and non-image uploads.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)
    with torch.no_grad():
        out = model(img_t.unsqueeze(0))  # add the batch dimension
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    # NOTE(review): index 1 of the "bio" head is presumably the
    # "biological" class -- confirm against the training labels.
    bio_p = bp[1].item()

    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }
async def predict_v2(file: UploadFile):
    """Classify an uploaded image and report extra confidence checks.

    On top of the ConvNeXt probabilities this computes the values returned
    by ``mc_uncertainty``, ``patch_consistency`` and the DINO ``similarity``
    helper, and feeds them all into ``final_decision_v2``.

    Args:
        file: Uploaded image; its bytes are decoded with PIL and converted
            to RGB before preprocessing.

    Returns:
        A dict with "decision", "model_outputs" (mold / biological
        probabilities) and "confidence_checks" (uncertainty, patch ratio,
        DINO similarity); all numbers are rounded to three decimals.

    Raises:
        HTTPException: 503 if DINO cannot be initialised (raised by
            ``ensure_dino``); 400 if the upload cannot be decoded as an
            image.
    """
    ensure_dino()
    # Import similarity lazily (only needed for v2)
    from dino import similarity

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass, so this also
        # covers empty and non-image uploads.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)
    with torch.no_grad():
        out = model(img_t.unsqueeze(0))  # add the batch dimension
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    # NOTE(review): index 1 of the "bio" head is presumably the
    # "biological" class -- confirm against the training labels.
    bio_p = bp[1].item()

    # Only the second value returned by mc_uncertainty (std_p) is used below.
    mean_p, std_p = mc_uncertainty(model, img_t, mold_idx)
    patch_ratio = patch_consistency(
        model, img, transform, mold_idx, device
    )
    dino_sim = similarity(
        dino, mold_embs, img, transform, device
    )

    decision = final_decision_v2(
        mold_p, bio_p, std_p, patch_ratio, dino_sim
    )
    return {
        "decision": decision,
        "model_outputs": {
            "mold_probability": round(mold_p, 3),
            "biological_probability": round(bio_p, 3),
        },
        "confidence_checks": {
            "uncertainty": round(std_p, 3),
            "patch_ratio": round(patch_ratio, 3),
            "dino_similarity": round(dino_sim, 3),
        },
    }
async def explain_gradcam(file: UploadFile):
    """Return the Grad-CAM map of the ConvNeXt model for the mold class.

    Args:
        file: Uploaded image; its bytes are decoded with PIL and converted
            to RGB before preprocessing.

    Returns:
        A dict with a single key "gradcam" holding the output of
        ``gradcam.generate`` converted with ``.tolist()``.

    Raises:
        HTTPException: 503 if no Grad-CAM target layer was found at
            startup (``gradcam`` is None); 400 if the upload cannot be
            decoded as an image.
    """
    # `gradcam` is None when find_last_conv2d() found no layer at import
    # time; fail with a clear error instead of an AttributeError.
    if gradcam is None:
        raise HTTPException(
            status_code=503,
            detail="Grad-CAM is unavailable: no target layer was found in the model backbone.",
        )

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass, so this also
        # covers empty and non-image uploads.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)
    cam = gradcam.generate(img_t, mold_idx)
    return {"gradcam": cam.tolist()}
async def predict_resnet(file: UploadFile):
    """Classify an uploaded image with the baseline ResNet model.

    Args:
        file: Uploaded image; its bytes are decoded with PIL and converted
            to RGB before preprocessing.

    Returns:
        A dict with "decision" (the result of ``final_decision``) plus
        "mold_probability" and "biological_probability", both rounded to
        three decimals.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass, so this also
        # covers empty and non-image uploads.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)
    with torch.no_grad():
        out = resnet_model(img_t.unsqueeze(0))  # add the batch dimension
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    # The ResNet checkpoint carries its own class order, hence its own index.
    mold_p = cp[resnet_mold_idx].item()
    # NOTE(review): index 1 of the "bio" head is presumably the
    # "biological" class -- confirm against the training labels.
    bio_p = bp[1].item()

    decision = final_decision(mold_p, bio_p)
    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }