Spaces:
Running
Running
| import os | |
| os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import cv2 | |
| import base64 | |
| from torchvision import transforms | |
| from torchvision.models import efficientnet_b0 | |
| from huggingface_hub import hf_hub_download | |
# FastAPI application object that the request handlers in this module use.
app = FastAPI(version="1.0", title="CliniScan API")

# Wide-open CORS policy: any origin may call the API, with any HTTP method
# and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Constants ──────────────────────────────────────────────
# Classifier output labels in logit-index order: element i of the model's
# output corresponds to CLASS_NAMES[i].
CLASS_NAMES = [
    'Aortic enlargement', 'Atelectasis', 'Calcification',
    'Cardiomegaly', 'Consolidation', 'ILD', 'Infiltration',
    'Lung Opacity', 'Nodule/Mass', 'Other lesion',
    'Pleural effusion', 'Pleural thickening',
    'Pneumothorax', 'Pulmonary fibrosis'
]
# Derived rather than hard-coded, so the classifier head size can never drift
# out of sync with the label list.
NUM_CLASSES = len(CLASS_NAMES)
# All inference runs on this device.
DEVICE = torch.device('cpu')
# Hugging Face Hub repository that hosts both model checkpoints.
HF_REPO = "luckysoni10/cliniscan-weights"
# ── Load Models ────────────────────────────────────────────
print("Loading models...")

# Classification model: an EfficientNet-B0 backbone whose final linear layer
# is replaced by a NUM_CLASSES-way head, with weights pulled from HF_REPO.
clf_model = efficientnet_b0(weights=None)
clf_model.classifier[1] = nn.Linear(
    clf_model.classifier[1].in_features, NUM_CLASSES)
clf_path = hf_hub_download(repo_id=HF_REPO, filename="m3_efficientnet_adamw.pth")
# The checkpoint is downloaded from a remote repo and is used directly as a
# state dict (tensors in a mapping). weights_only=True restricts unpickling to
# that, so a tampered file cannot execute arbitrary code at load time.
state = torch.load(clf_path, map_location=DEVICE, weights_only=True)
clf_model.load_state_dict(state)
# Request tensors are moved to DEVICE, so the parameters must live there too.
clf_model.to(DEVICE)
clf_model.eval()

# Detection model
from ultralytics import YOLO
det_path = hf_hub_download(repo_id=HF_REPO, filename="best.pt")
# NOTE(review): YOLO(...) presumably unpickles this .pt file via torch.load.
# Only point HF_REPO at a repository you trust.
det_model = YOLO(det_path)
print("β Models loaded!")
# ── Transforms ─────────────────────────────────────────────
# Preprocessing applied to every PIL input before it reaches the classifier:
#   1. resize to a fixed 224x224;
#   2. convert to a float tensor;
#   3. normalize each channel with the mean/std below (the commonly used
#      ImageNet statistics; presumably matching how the model was trained).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
# ── Grad-CAM ───────────────────────────────────────────────
class GradCAM:
    """Gradient-weighted Class Activation Mapping for one layer of a model.

    Hooks are registered on ``target_layer`` at construction time, so that:

    * every forward pass of ``model`` stores the layer's output in
      ``activations``;
    * every backward pass stores the gradient w.r.t. that output in
      ``gradients``.

    Both are detached copies.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached later via remove().
        self._handles = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, module, inputs, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from the target layer."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, tensor, class_idx):
        """Compute a Grad-CAM map for output ``class_idx`` of sample 0.

        Args:
            tensor: model input batch; only the first sample is explained.
            class_idx: index into the model's output logits (dim 1).

        Returns:
            2-D float numpy array with the spatial size of the target layer's
            output. Values are min-max scaled to [0, 1]; the map is all zeros
            if the ReLU'd CAM is constant.
        """
        self.model.zero_grad()
        # Backward needs an autograd graph even if the caller is inside a
        # torch.no_grad() block.
        with torch.enable_grad():
            out = self.model(tensor)
            out[0, class_idx].backward()
        # Assumes an NCHW layer output, so acts is (C, H, W).
        acts = self.activations[0]
        # One weight per channel: the spatially averaged gradient.
        weights = self.gradients[0].mean(dim=(1, 2))
        cam = torch.relu((weights[:, None, None] * acts).sum(dim=0))
        # Divide by the range, not just the max, so the result spans [0, 1].
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam.cpu().numpy()
# Module-wide Grad-CAM helper attached to the last stage of the classifier's
# feature extractor. Its forward hook fires on every clf_model forward pass,
# not only the ones made by GradCAM.generate.
gradcam = GradCAM(model=clf_model, target_layer=clf_model.features[-1])
# ── Helper ─────────────────────────────────────────────────
def read_image(file_bytes):
    """Decode raw uploaded bytes into an RGB PIL image.

    Args:
        file_bytes: the raw contents of an uploaded file.

    Returns:
        A fully loaded ``PIL.Image.Image`` in mode ``'RGB'``.

    Raises:
        fastapi.HTTPException: status 400 if the bytes cannot be decoded as
            an image (unknown format, truncated data, or a decompression-bomb
            sized image). Without this, such uploads surface as a 500 error.
    """
    # Local import keeps this helper self-contained.
    from fastapi import HTTPException

    try:
        # convert() returns a new, fully loaded image, so closing the source
        # afterwards is safe.
        with Image.open(io.BytesIO(file_bytes)) as src:
            return src.convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # OSError also covers PIL.UnidentifiedImageError and truncated files.
        raise HTTPException(
            status_code=400,
            detail="Could not decode uploaded file as an image",
        ) from err
def img_to_base64(img_array):
    """Encode an OpenCV image array as PNG and return it as a base64 string.

    Args:
        img_array: image array accepted by ``cv2.imencode``. The callers in
            this file pass BGR uint8 images.

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        RuntimeError: if OpenCV reports that PNG encoding failed. Previously
            that success flag was ignored, and a failed encode would have
            produced a bogus payload.
    """
    ok, buf = cv2.imencode('.png', img_array)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buf.tobytes()).decode('utf-8')
# ── Routes ─────────────────────────────────────────────────
def root():
    """Return a fixed status message saying the API is running."""
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    payload = {"status": "CliniScan API is running β "}
    return payload
def health():
    """Return a static health payload.

    The ``models`` field is a constant string; nothing is checked at call
    time.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    return dict(status="ok", models="loaded")
async def classify(file: UploadFile = File(...), threshold: float = 0.3):
    """Run multi-label classification on one uploaded image.

    Args:
        file: the uploaded image.
        threshold: a class is marked ``detected`` when its sigmoid
            probability is >= this value.

    Returns:
        JSONResponse with:

        * ``all_probs``: one entry per class, sorted by probability, highest
          first;
        * ``detected``: the entries at or above ``threshold``;
        * ``threshold``: the threshold that was used.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    raw = await file.read()
    batch_tensor = transform(read_image(raw)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = clf_model(batch_tensor)
        scores = torch.sigmoid(logits)[0].numpy()

    # sorted() is stable, so classes with equal rounded probability keep
    # their label order.
    all_probs = sorted(
        (
            {
                "class_id": idx,
                "class_name": label,
                "probability": round(float(score), 4),
                "detected": bool(score >= threshold),
            }
            for idx, (label, score) in enumerate(zip(CLASS_NAMES, scores))
        ),
        key=lambda entry: entry['probability'],
        reverse=True,
    )
    flagged = [entry for entry in all_probs if entry['detected']]

    return JSONResponse({
        "status": "success",
        "detected": flagged,
        "all_probs": all_probs,
        "threshold": threshold,
    })
async def detect(file: UploadFile = File(...), confidence: float = 0.25):
    """Run the YOLO detector on one uploaded image and annotate the result.

    The image is resized to 224x224 before detection, so every ``bbox`` is in
    that 224x224 pixel space.

    Args:
        file: the uploaded image.
        confidence: minimum detection confidence, forwarded to
            ``det_model.predict(conf=...)``.

    Returns:
        JSONResponse with:

        * ``boxes``: one entry per detection, giving the class name, the
          confidence rounded to 3 places and integer ``[x1, y1, x2, y2]``;
        * ``total_found``: the number of boxes;
        * ``annotated_image``: the boxed image as a base64-encoded PNG.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    # Box/label colours indexed by class id (cycled when ids exceed the
    # palette). They are applied to the BGR image img_cv. The palette is built
    # once per request instead of once per detected box.
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
              (255, 0, 255), (0, 255, 255), (128, 0, 0), (0, 128, 0),
              (0, 0, 128), (128, 128, 0), (128, 0, 128), (0, 128, 128),
              (64, 0, 0), (0, 64, 0)]

    img_bytes = await file.read()
    img = read_image(img_bytes)
    img_cv = cv2.cvtColor(np.array(img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    results = det_model.predict(img_cv, conf=confidence, verbose=False)

    boxes = []
    for box in results[0].boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        # A detector whose label set is larger than CLASS_NAMES would
        # otherwise raise IndexError and fail the whole request.
        if 0 <= cls_id < len(CLASS_NAMES):
            name = CLASS_NAMES[cls_id]
        else:
            name = f"class_{cls_id}"
        color = colors[cls_id % len(colors)]
        cv2.rectangle(img_cv, (x1, y1), (x2, y2), color, 2)
        # Clamp the label's y to at least 10 so it stays visible for boxes at
        # the top edge.
        cv2.putText(img_cv, f"{name} {conf:.2f}",
                    (x1, max(y1 - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1)
        boxes.append({
            "class_name": name,
            "confidence": round(conf, 3),
            "bbox": [x1, y1, x2, y2]
        })

    return JSONResponse({
        "status": "success",
        "boxes": boxes,
        "total_found": len(boxes),
        "annotated_image": img_to_base64(img_cv)
    })
async def gradcam_endpoint(file: UploadFile = File(...)):
    """Explain the classifier's top-scoring class with a Grad-CAM overlay.

    Args:
        file: the uploaded image.

    Returns:
        JSONResponse with:

        * the predicted (argmax) class name;
        * its sigmoid probability, rounded to 4 places;
        * the heatmap blended 50/50 over the 224x224 input;
        * the 224x224 input itself.

        Both images are base64-encoded PNGs.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    raw = await file.read()
    pil_img = read_image(raw)
    base_bgr = cv2.cvtColor(np.array(pil_img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    x = transform(pil_img).unsqueeze(0).to(DEVICE)
    x.requires_grad_(True)

    # Pick the class to explain with a gradient-free pass.
    with torch.no_grad():
        scores = torch.sigmoid(clf_model(x))[0]
    best = int(scores.argmax())

    # gradcam.generate does its own forward/backward pass and returns a map
    # scaled to [0, 1], which is upsampled here to the display size.
    cam_224 = cv2.resize(gradcam.generate(x, best), (224, 224))
    colored = cv2.applyColorMap((cam_224 * 255).astype(np.uint8), cv2.COLORMAP_JET)
    blended = cv2.addWeighted(base_bgr, 0.5, colored, 0.5, 0)

    return JSONResponse({
        "status": "success",
        "predicted_class": CLASS_NAMES[best],
        "confidence": round(float(scores[best]), 4),
        "heatmap_image": img_to_base64(blended),
        "original_image": img_to_base64(base_bgr),
    })
async def batch(files: list[UploadFile] = File(...), threshold: float = 0.3):
    """Classify several uploaded images, one forward pass per file.

    Args:
        files: the uploaded images, processed in order.
        threshold: a class is reported for a file when its sigmoid
            probability is >= this value.

    Returns:
        JSONResponse whose ``results`` list has one entry per file. Each entry
        holds the filename, the classes at or above ``threshold`` (in label
        order) and how many there were.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed source; confirm it is registered on `app`.
    per_file = []
    for upload in files:
        pixels = transform(read_image(await upload.read())).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            scores = torch.sigmoid(clf_model(pixels))[0].numpy()

        findings = []
        for idx, score in enumerate(scores):
            if score >= threshold:
                findings.append({
                    "class_name": CLASS_NAMES[idx],
                    "probability": round(float(score), 4),
                })

        per_file.append({
            "filename": upload.filename,
            "detected": findings,
            "finding_count": len(findings),
        })

    return JSONResponse({"status": "success", "results": per_file})