Spaces:
Sleeping
Sleeping
File size: 6,124 Bytes
7a5f7fb 8cc2137 511fc83 8cc2137 7a5f7fb 4bb02cf 511fc83 7a5f7fb 8cc2137 2b8b06c 8cc2137 511fc83 8cc2137 c46050c bef3d34 6d5d66b bef3d34 c46050c bef3d34 c46050c bef3d34 c46050c bef3d34 4bb02cf 7d6580c 4bb02cf 7d6580c 511fc83 8cc2137 4bb02cf 511fc83 4bb02cf 8cc2137 4bb02cf 8cc2137 4bb02cf 8cc2137 4bb02cf 7a5f7fb 511fc83 4bb02cf 7a5f7fb 4bb02cf 7a5f7fb 4bb02cf 511fc83 2b8b06c 511fc83 4bb02cf 511fc83 4bb02cf 7a5f7fb 4bb02cf 511fc83 4bb02cf 511fc83 4bb02cf 511fc83 4bb02cf 511fc83 8cc2137 511fc83 4bb02cf 511fc83 4bb02cf 8cc2137 4bb02cf 511fc83 bef3d34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 | from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch, io
from torchvision import transforms
from model import MultiTaskResNet50, MultiTaskConvNeXt, find_last_conv2d
from decision import final_decision
from advanced_decision import (
mc_uncertainty,
patch_consistency,
final_decision_v2
)
from gradcam import GradCAM
from typing import Optional
# ------------------
# Application, CORS and compute device
# ------------------
app = FastAPI(title="Mold Detection API (ResNet + ConvNeXt)")

# Wildcard CORS policy: every origin, HTTP method and request header is
# accepted by the middleware.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ------------------
# Load baseline model (ResNet)
# ------------------
resnet_ckpt = torch.load("resnet50_multitask_mold.pth", map_location=device)

# Two checkpoint layouts are accepted:
#   * a dict with a "model" entry (state_dict) and an optional "classes" list
#   * anything else, which is treated as the state_dict itself
_resnet_is_wrapped = isinstance(resnet_ckpt, dict) and "model" in resnet_ckpt
resnet_state = resnet_ckpt["model"] if _resnet_is_wrapped else resnet_ckpt
resnet_classes = resnet_ckpt.get("classes", []) if _resnet_is_wrapped else []

# Without class names, fall back to 9 output classes.
if resnet_classes:
    resnet_num_classes = len(resnet_classes)
else:
    resnet_num_classes = 9

# Without class names, or without a "mold" entry among them, fall back to
# index 4 for the mold class.
if resnet_classes and "mold" in resnet_classes:
    resnet_mold_idx = resnet_classes.index("mold")
else:
    resnet_mold_idx = 4

resnet_model = MultiTaskResNet50(resnet_num_classes).to(device)
resnet_model.load_state_dict(resnet_state)
resnet_model.eval()
# ------------------
# Load main model (ConvNeXt)
# ------------------
# Preferred checkpoint layout is a dict with keys:
# - "model": state_dict
# - "classes": list of class names (length N, mold at some index)
# A checkpoint without a "model" key is treated as the state_dict itself,
# mirroring how the ResNet checkpoint is handled.
ckpt = torch.load("best_convnext_multitask.pth", map_location=device)
if isinstance(ckpt, dict) and "model" in ckpt:
    _convnext_state = ckpt["model"]
    classes = ckpt.get("classes") or []
else:
    _convnext_state = ckpt
    classes = []

# Without class names, fall back to 9 output classes.
num_classes = len(classes) if classes else 9

# Look up the mold class by name. When the list is missing or does not
# contain "mold", fall back to index 4 instead of raising ValueError at
# import time (same fallback as the ResNet loader).
mold_idx = classes.index("mold") if "mold" in classes else 4

model = MultiTaskConvNeXt(num_classes).to(device)
model.load_state_dict(_convnext_state)
model.eval()
# ------------------
# Transforms
# ------------------
# Per-channel mean / std handed to Normalize (these values are the standard
# ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image: resize to 224x224, convert
# to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# ------------------
# Grad-CAM target layer (computed, not stored in model state_dict)
# ------------------
target_layer = find_last_conv2d(model.backbone)

# Grad-CAM is only available when a target layer was found in the backbone;
# otherwise `gradcam` stays None.
if target_layer is None:
    gradcam = None
else:
    gradcam = GradCAM(model, target_layer)
# ------------------
# DINO (lazy loaded)
# ------------------
dino: Optional[object] = None
mold_embs = None


def ensure_dino():
    """Initialize the DINO model and reference embeddings on first use.

    Later calls are no-ops once initialization has succeeded. The two
    module-level globals ``dino`` and ``mold_embs`` are published together,
    only after both the model load and the embedding build have succeeded.
    A failure half-way through therefore leaves both unset and the next
    call retries from scratch.

    Raises:
        HTTPException: status 503 when the ``dino`` module cannot be
            imported, or when loading the model / building the embeddings
            fails.
    """
    global dino, mold_embs

    if dino is not None:
        return

    try:
        from dino import load_dino, build_embeddings
    except ModuleNotFoundError as e:
        # Local/dev env might not have optional deps like `datasets`.
        raise HTTPException(
            status_code=503,
            detail=(
                "DINO dependencies are not installed. "
                "Install extras with: pip install datasets scikit-learn"
            ),
        ) from e

    try:
        # Build into locals first so a failure in build_embeddings does not
        # leave `dino` set while `mold_embs` is still None.
        loaded_dino = load_dino(device)
        built_embs = build_embeddings(loaded_dino, transform, device)
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Failed to initialize DINO reference embeddings: {e}",
        ) from e

    dino, mold_embs = loaded_dino, built_embs
# ------------------
# API endpoints
# ------------------
@app.post("/predict/v1")
async def predict_v1(file: UploadFile):
    """Run the ConvNeXt model on an uploaded image and return the v1 decision.

    Args:
        file: Uploaded image file.

    Returns:
        A dict with the ``final_decision`` result under "decision", plus the
        mold-class probability and the probability at index 1 of the "bio"
        head, each rounded to 3 decimals.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL signals unrecognized or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); report it as a client error
        # instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    bio_p = bp[1].item()
    decision = final_decision(mold_p, bio_p)

    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }
@app.post("/predict/v2")
async def predict_v2(file: UploadFile):
    """Run the ConvNeXt model plus extra confidence checks on an upload.

    On top of the raw model probabilities this computes an uncertainty value
    (``mc_uncertainty``), a patch ratio (``patch_consistency``) and a DINO
    similarity score, and feeds all of them to ``final_decision_v2``.

    Args:
        file: Uploaded image file.

    Returns:
        A dict with "decision", "model_outputs" (mold / biological
        probabilities) and "confidence_checks" (uncertainty, patch ratio,
        DINO similarity), all numeric values rounded to 3 decimals.

    Raises:
        HTTPException: status 503 when DINO cannot be initialized (raised by
            ``ensure_dino``); status 400 when the upload cannot be decoded as
            an image.
    """
    ensure_dino()
    # Import similarity lazily (only needed for v2)
    from dino import similarity

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL signals unrecognized or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); report it as a client error
        # instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[mold_idx].item()
    bio_p = bp[1].item()

    # Only the second element of the returned pair is used downstream.
    _, std_p = mc_uncertainty(model, img_t, mold_idx)
    patch_ratio = patch_consistency(
        model, img, transform, mold_idx, device
    )
    dino_sim = similarity(
        dino, mold_embs, img, transform, device
    )

    decision = final_decision_v2(
        mold_p, bio_p, std_p, patch_ratio, dino_sim
    )

    return {
        "decision": decision,
        "model_outputs": {
            "mold_probability": round(mold_p, 3),
            "biological_probability": round(bio_p, 3),
        },
        "confidence_checks": {
            "uncertainty": round(std_p, 3),
            "patch_ratio": round(patch_ratio, 3),
            "dino_similarity": round(dino_sim, 3),
        },
    }
@app.post("/explain/gradcam")
async def explain_gradcam(file: UploadFile):
    """Return the Grad-CAM map for the mold class of an uploaded image.

    Args:
        file: Uploaded image file.

    Returns:
        A dict whose "gradcam" entry is the result of ``gradcam.generate``
        converted with ``.tolist()``.

    Raises:
        HTTPException: status 503 when no Grad-CAM instance is available
            (no target layer was found at startup); status 400 when the
            upload cannot be decoded as an image.
    """
    # `gradcam` is None when no target layer was found; fail with a clear
    # status instead of an AttributeError turning into a 500.
    if gradcam is None:
        raise HTTPException(
            status_code=503,
            detail="Grad-CAM is unavailable: no target layer was found in the model backbone.",
        )

    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL signals unrecognized or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); report it as a client error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)
    cam = gradcam.generate(img_t, mold_idx)
    return {"gradcam": cam.tolist()}
@app.post("/predict/resnet")
async def predict_resnet(file: UploadFile):
    """Run the baseline ResNet model on an uploaded image.

    Args:
        file: Uploaded image file.

    Returns:
        A dict with the ``final_decision`` result under "decision", plus the
        mold-class probability and the probability at index 1 of the "bio"
        head, each rounded to 3 decimals.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as e:
        # PIL signals unrecognized or truncated image data with OSError
        # (UnidentifiedImageError is a subclass); report it as a client error
        # instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from e

    img_t = transform(img).to(device)

    with torch.no_grad():
        out = resnet_model(img_t.unsqueeze(0))
        cp = torch.softmax(out["class"], 1)[0]
        bp = torch.softmax(out["bio"], 1)[0]

    mold_p = cp[resnet_mold_idx].item()
    bio_p = bp[1].item()
    decision = final_decision(mold_p, bio_p)

    return {
        "decision": decision,
        "mold_probability": round(mold_p, 3),
        "biological_probability": round(bio_p, 3),
    }
|