"""CliniScan API: FastAPI service for chest X-ray classification, detection and Grad-CAM.

Importing this module downloads the model weights from the Hugging Face Hub
and loads both models on the CPU, so the import is slow and needs network
access (or a populated Hugging Face cache).
"""
import os

# Kept ahead of the torch / cv2 imports, exactly as in the original ordering:
# the flag only has an effect if it is set before the OpenMP runtime loads.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import base64
import io

import cv2
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b0
from ultralytics import YOLO

app = FastAPI(title="CliniScan API", version="1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Constants ──────────────────────────────────────────────
CLASS_NAMES = [
    'Aortic enlargement', 'Atelectasis', 'Calcification',
    'Cardiomegaly', 'Consolidation', 'ILD', 'Infiltration',
    'Lung Opacity', 'Nodule/Mass', 'Other lesion',
    'Pleural effusion', 'Pleural thickening',
    'Pneumothorax', 'Pulmonary fibrosis'
]
NUM_CLASSES = len(CLASS_NAMES)  # 14
DEVICE = torch.device('cpu')
HF_REPO = "luckysoni10/cliniscan-weights"
IMAGE_SIZE = (224, 224)

# One BGR colour per class for the detection overlay (built once, not per box).
BOX_COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
    (255, 0, 255), (0, 255, 255), (128, 0, 0), (0, 128, 0),
    (0, 0, 128), (128, 128, 0), (128, 0, 128), (0, 128, 128),
    (64, 0, 0), (0, 64, 0),
]

# ── Load Models ────────────────────────────────────────────
print("Loading models...")

# Classification model
clf_model = efficientnet_b0(weights=None)
clf_model.classifier[1] = nn.Linear(
    clf_model.classifier[1].in_features, NUM_CLASSES)
clf_path = hf_hub_download(repo_id=HF_REPO,
                           filename="m3_efficientnet_adamw.pth")
# NOTE(security): torch.load unpickles the downloaded file. If the installed
# torch supports it, passing weights_only=True restricts what can be loaded.
state = torch.load(clf_path, map_location=DEVICE)
clf_model.load_state_dict(state)
clf_model.eval()

# Detection model
det_path = hf_hub_download(repo_id=HF_REPO, filename="best.pt")
det_model = YOLO(det_path)
print("✅ Models loaded!")

# ── Transforms ─────────────────────────────────────────────
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# ── Grad-CAM ───────────────────────────────────────────────
class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a model.

    Hooks on ``target_layer`` record its output on every forward pass and the
    gradient with respect to that output on every backward pass.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        target_layer.register_forward_hook(self._save_activations)
        target_layer.register_full_backward_hook(self._save_gradients)

    def _save_activations(self, module, inputs, output):
        # Returns None on purpose: a non-None return would replace the output.
        self.activations = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, tensor, class_idx):
        """Return the Grad-CAM map for ``class_idx`` as a 2-D array in [0, 1].

        Runs a forward and a backward pass through the model, so it must be
        called with gradient tracking enabled (not inside ``torch.no_grad()``).
        Only the first item of the batch is used.
        """
        self.model.zero_grad()
        out = self.model(tensor)
        out[0, class_idx].backward()

        # Channel weights: gradient averaged over the two spatial dimensions.
        weights = self.gradients[0].mean(dim=(1, 2))
        cam = (weights[:, None, None] * self.activations[0]).sum(dim=0)
        cam = torch.relu(cam)

        # Shift to zero first, then scale by the *shifted* maximum. Dividing by
        # the pre-shift maximum left the peak below 1.0 whenever min > 0.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.numpy()


gradcam = GradCAM(clf_model, clf_model.features[-1])


# ── Helper ─────────────────────────────────────────────────
def read_image(file_bytes):
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: status 400 when the bytes cannot be decoded as an
            image (empty, truncated or unsupported upload).
    """
    try:
        return Image.open(io.BytesIO(file_bytes)).convert('RGB')
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from err


def img_to_base64(img_array):
    """Encode an OpenCV image array as a base64 PNG string.

    Raises:
        RuntimeError: when OpenCV reports that PNG encoding failed.
    """
    ok, buf = cv2.imencode('.png', img_array)
    if not ok:
        raise RuntimeError("PNG encoding failed")
    return base64.b64encode(buf).decode('utf-8')


def preprocess(img):
    """Turn a PIL image into a normalized 1-item batch tensor on ``DEVICE``."""
    return transform(img).unsqueeze(0).to(DEVICE)


def predict_probs(tensor):
    """Return per-class sigmoid probabilities for the first item of ``tensor``."""
    with torch.no_grad():
        return torch.sigmoid(clf_model(tensor))[0].numpy()


def to_bgr_thumbnail(img):
    """Resize a PIL RGB image to ``IMAGE_SIZE`` and convert it to a BGR array."""
    return cv2.cvtColor(np.array(img.resize(IMAGE_SIZE)), cv2.COLOR_RGB2BGR)


def class_label(cls_id):
    """Map a class id to its name, tolerating ids outside ``CLASS_NAMES``."""
    if 0 <= cls_id < len(CLASS_NAMES):
        return CLASS_NAMES[cls_id]
    return f"class_{cls_id}"


# ── Routes ─────────────────────────────────────────────────
@app.get("/")
def root():
    """Landing route confirming the service is up."""
    return {"status": "CliniScan API is running ✅"}


@app.get("/health")
def health():
    """Health probe; models are loaded at import time, so reaching it means ready."""
    return {"status": "ok", "models": "loaded"}


@app.post("/predict/classify")
async def classify(file: UploadFile = File(...), threshold: float = 0.3):
    """Classify one image and report every class probability.

    ``all_probs`` lists all classes sorted by descending probability;
    ``detected`` is the subset with probability >= ``threshold``.
    """
    img = read_image(await file.read())
    probs = predict_probs(preprocess(img))

    results = [
        {
            "class_id": i,
            "class_name": name,
            "probability": round(float(prob), 4),
            "detected": bool(prob >= threshold),
        }
        for i, (name, prob) in enumerate(zip(CLASS_NAMES, probs))
    ]
    results.sort(key=lambda r: r['probability'], reverse=True)
    detected = [r for r in results if r['detected']]

    return JSONResponse({
        "status": "success",
        "detected": detected,
        "all_probs": results,
        "threshold": threshold
    })


@app.post("/predict/detect")
async def detect(file: UploadFile = File(...), confidence: float = 0.25):
    """Run the detector on one image and return boxes plus an annotated PNG.

    The image is resized to 224x224 before detection, so ``bbox`` coordinates
    refer to that resized image, not to the original upload.
    """
    img = read_image(await file.read())
    img_cv = to_bgr_thumbnail(img)

    results = det_model.predict(img_cv, conf=confidence, verbose=False)
    boxes = []
    for box in results[0].boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        label = class_label(cls_id)
        color = BOX_COLORS[cls_id % len(BOX_COLORS)]

        cv2.rectangle(img_cv, (x1, y1), (x2, y2), color, 2)
        # Keep the caption inside the image when the box touches the top edge.
        cv2.putText(img_cv, f"{label} {conf:.2f}",
                    (x1, max(y1 - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1)
        boxes.append({
            "class_name": label,
            "confidence": round(conf, 3),
            "bbox": [x1, y1, x2, y2]
        })

    return JSONResponse({
        "status": "success",
        "boxes": boxes,
        "total_found": len(boxes),
        "annotated_image": img_to_base64(img_cv)
    })


@app.post("/predict/gradcam")
async def gradcam_endpoint(file: UploadFile = File(...)):
    """Return a Grad-CAM overlay for the highest-probability class of one image."""
    img = read_image(await file.read())
    orig = to_bgr_thumbnail(img)

    tensor = preprocess(img)
    tensor.requires_grad_(True)

    with torch.no_grad():
        probs = torch.sigmoid(clf_model(tensor))[0]
    top_class = int(probs.argmax())

    # Must run outside no_grad: generate() calls backward().
    cam = gradcam.generate(tensor, top_class)
    cam_resized = cv2.resize(cam, IMAGE_SIZE)
    heatmap = cv2.applyColorMap(
        (cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(orig, 0.5, heatmap, 0.5, 0)

    return JSONResponse({
        "status": "success",
        "predicted_class": CLASS_NAMES[top_class],
        "confidence": round(float(probs[top_class]), 4),
        "heatmap_image": img_to_base64(overlay),
        "original_image": img_to_base64(orig)
    })


@app.post("/predict/batch")
async def batch(files: list[UploadFile] = File(...),
                threshold: float = 0.3):
    """Classify several images; per file, list the classes at or above ``threshold``."""
    batch_results = []
    for file in files:
        img = read_image(await file.read())
        probs = predict_probs(preprocess(img))
        detected = [
            {"class_name": CLASS_NAMES[i], "probability": round(float(p), 4)}
            for i, p in enumerate(probs) if p >= threshold
        ]
        batch_results.append({
            "filename": file.filename,
            "detected": detected,
            "finding_count": len(detected)
        })
    return JSONResponse({"status": "success", "results": batch_results})