Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| # Import models | |
| from classification.model import CliniScanClassifier | |
| from detection.model import get_detection_model | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
# Browser origins allowed to call this API. Set CORS_ALLOW_ORIGINS to a
# comma-separated list (e.g. the deployed frontend URL) to restrict access;
# when it is unset or empty, every origin is accepted.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app = FastAPI(title="MediScanAI API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentialed requests (the
    # FastAPI CORS docs require explicit origins when allow_credentials=True),
    # so credentials are only enabled once the origin list is explicit.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Inference runs on CPU (the Hugging Face free tier uses CPU).
device = torch.device("cpu")

# Model handles shared by the request handlers. Both are None at import time
# and are assigned by load_models().
class_model = None
detect_model = None
# Labels for the classifier outputs; position i names output i.
CLASS_NAMES_CLASSIFY = [
    "Aortic enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "ILD",
    "Infiltration",
    "Lung Opacity",
    "Nodule/Mass",
    "Other lesion",
    "Pleural effusion",
    "Pleural thickening",
    "Pneumothorax",
    "Pulmonary fibrosis",
    "No finding",
]

# Labels for the detector: the same names shifted by one, with index 0
# reserved for "Background".
CLASS_NAMES_DETECT = ["Background", *CLASS_NAMES_CLASSIFY]
def get_inference_transforms():
    """Build the preprocessing pipeline applied to an image before classification.

    Returns:
        An ``A.Compose`` that resizes to 256x256, normalizes each channel
        with the fixed mean/std below, and finishes with ``ToTensorV2``.
    """
    resize = A.Resize(256, 256)
    normalize = A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    steps = [resize, normalize, ToTensorV2()]
    return A.Compose(steps)
def _load_weights(model, weights_path):
    """Load the state dict at ``weights_path`` into ``model``.

    Returns the same model, moved to ``device`` and switched to eval mode.
    Any error from reading or applying the weights propagates to the caller.
    """
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# NOTE(review): nothing visible in this file calls or registers load_models
# (e.g. as a startup hook) - confirm it is wired up.
async def load_models():
    """Load the classification and detection models into the module globals.

    Loading is best-effort: failures are printed, not raised. The two models
    are loaded independently, so a failure in one does not prevent the other
    from loading. A global is only assigned after its weights were applied
    successfully, so a failed load never leaves a freshly constructed model
    without weights available to the request handlers.
    """
    global class_model, detect_model

    try:
        print("Loading Classification Model...")
        # One output per classifier label.
        class_model = _load_weights(
            CliniScanClassifier(len(CLASS_NAMES_CLASSIFY)),
            'models/best_resnet_classification.pth',
        )
        print("Classification Model Loaded.")
    except Exception as e:
        print(f"Error loading models: {e}")

    try:
        print("Loading Detection Model...")
        # One class per detector label, including "Background".
        detect_model = _load_weights(
            get_detection_model(len(CLASS_NAMES_DETECT)),
            'models/best_faster_rcnn_detection.pth',
        )
        print("Detection Model Loaded.")
    except Exception as e:
        print(f"Error loading models: {e}")
def read_root():
    """Return a small status payload reporting that the API is running."""
    status_payload = {"status": "MediScanAI API is running"}
    return status_payload
# Side length, in pixels, of the square image handed to the detector; the
# boxes it returns are expressed in this coordinate space.
_DETECT_INPUT_SIZE = 256

# Predictions at or below this score are dropped server-side; the frontend
# applies its own, stricter threshold on top.
_MIN_SCORE = 0.1

# Model labels that the frontend knows under a different name.
_FRONTEND_NAMES = {
    "Pleural effusion": "Effusion",
    "Nodule/Mass": "Mass",
    "Pleural thickening": "Pleural_Thickening",
}


def _frontend_name(name):
    """Return the frontend's name for a model label (unchanged if unmapped)."""
    return _FRONTEND_NAMES.get(name, name)


def _by_score_desc(predictions):
    """Return the predictions ordered from highest to lowest score."""
    return sorted(predictions, key=lambda p: p["score"], reverse=True)


def _classify(image_rgb):
    """Run the classifier on an RGB array and return its scored labels."""
    # NOTE(review): the classifier is fed BGR while the detector is fed RGB -
    # confirm each channel order matches how the model was trained.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    transformed = get_inference_transforms()(image=image_bgr)
    batch = transformed["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        probs = torch.sigmoid(class_model(batch)).cpu().numpy()[0]

    predictions = [
        {"name": _frontend_name(label), "score": float(prob)}
        for label, prob in zip(CLASS_NAMES_CLASSIFY, probs)
        if prob > _MIN_SCORE
    ]
    return _by_score_desc(predictions)


def _detect(image_rgb):
    """Run the detector on an RGB array and return its scored boxes.

    Box coordinates are returned as percentages (0-100) of the image width
    and height, which is what the frontend's SVG viewBox expects.
    """
    resized = cv2.resize(image_rgb, (_DETECT_INPUT_SIZE, _DETECT_INPUT_SIZE))
    chw = torch.as_tensor(resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
    batch = chw.unsqueeze(0).to(device)

    with torch.no_grad():
        output = detect_model(batch)[0]

    # Both axes were resized to the same length, so one scale factor converts
    # every coordinate to a percentage.
    boxes_pct = output["boxes"].cpu().numpy() / float(_DETECT_INPUT_SIZE) * 100
    labels = output["labels"].cpu().numpy()
    scores = output["scores"].cpu().numpy()

    predictions = []
    for label, score, box in zip(labels, scores, boxes_pct.tolist()):
        if score > _MIN_SCORE:
            x_min, y_min, x_max, y_max = box
            predictions.append({
                "name": _frontend_name(CLASS_NAMES_DETECT[int(label)]),
                "score": float(score),
                "box": {
                    "xMin": x_min, "yMin": y_min, "xMax": x_max, "yMax": y_max
                }
            })
    return _by_score_desc(predictions)


# NOTE(review): no route decorator is visible on this handler in this view -
# confirm it is registered on `app`.
async def predict(file: UploadFile = File(...), mode: str = Form("classify")):
    """Run the requested model on an uploaded image.

    Args:
        file: The uploaded image.
        mode: "classify" for per-label scores, "detect" for scored boxes.

    Returns:
        ``{"mode": mode, "predictions": [...]}`` with predictions sorted by
        descending score, or ``{"error": "..."}`` when the mode is unknown,
        the requested model has not been loaded, or the upload cannot be
        decoded as an image.
    """
    if mode == "classify":
        model, run = class_model, _classify
    elif mode == "detect":
        model, run = detect_model, _detect
    else:
        return {"error": "Invalid mode"}

    # load_models() is best-effort, so a model may still be None here.
    if model is None:
        return {"error": "Model not loaded"}

    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert('RGB')
    except OSError:
        # Covers unrecognized formats as well as truncated/corrupt files.
        return {"error": "Invalid image"}

    # NOTE: inference runs synchronously inside this coroutine.
    return {"mode": mode, "predictions": run(np.array(image))}