mediscan-api / app.py
Mittalyash's picture
Upload folder using huggingface_hub
9916246 verified
import os
import io
import torch
import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Import models
from classification.model import CliniScanClassifier
from detection.model import get_detection_model
import albumentations as A
from albumentations.pytorch import ToTensorV2
app = FastAPI(title="MediScanAI API")

# Origins allowed to call the API, taken from the ALLOWED_ORIGINS environment
# variable as a comma-separated list (e.g. "https://my-app.vercel.app").
# When the variable is unset or empty every origin is allowed, which is the
# historical behaviour of this service.
# NOTE(review): the FastAPI CORS documentation says wildcard origins should not
# be combined with allow_credentials=True; set ALLOWED_ORIGINS to the Vercel
# URL in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Allow CORS for Vercel
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# All inference runs on the CPU (Hugging Face free tier).
device = torch.device('cpu')

# Model handles shared by the request handlers; they stay None until the
# startup hook assigns them.
class_model = detect_model = None

# Labels for the classifier outputs, in output-index order.
CLASS_NAMES_CLASSIFY = [
    "Aortic enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "ILD",
    "Infiltration",
    "Lung Opacity",
    "Nodule/Mass",
    "Other lesion",
    "Pleural effusion",
    "Pleural thickening",
    "Pneumothorax",
    "Pulmonary fibrosis",
    "No finding",
]

# Labels for the detector: the same names shifted by one, with index 0
# reserved for the background class.
CLASS_NAMES_DETECT = ["Background", *CLASS_NAMES_CLASSIFY]
def get_inference_transforms():
    """Build the preprocessing pipeline used before classification.

    Returns:
        An albumentations ``Compose`` that resizes the image to 256x256,
        normalizes each channel with a fixed mean/std, and applies
        ``ToTensorV2``.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(256, 256),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def _load_model(model, weights_path):
    """Load weights from *weights_path* into *model* and prepare it for inference.

    Args:
        model: A freshly constructed model whose ``state_dict`` matches the
            checkpoint.
        weights_path: Path of the checkpoint file read with ``torch.load``.

    Returns:
        The same model, moved to ``device`` and switched to eval mode.

    Raises:
        Exception: Whatever ``torch.load`` or ``load_state_dict`` raises, e.g.
            when the file is missing or the weights do not match the model.
    """
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    return model


@app.on_event("startup")
async def load_models():
    """Load the classification and detection models when the app starts.

    Each model is loaded independently so that a failure in one does not
    prevent the other from being served. A global is only assigned once its
    weights were loaded successfully; on failure it is left as ``None`` so a
    model with uninitialised weights can never be used for predictions.
    Errors are printed rather than raised, which keeps the API process up.
    """
    global class_model, detect_model

    try:
        print("Loading Classification Model...")
        class_model = _load_model(
            CliniScanClassifier(15),
            'models/best_resnet_classification.pth',
        )
        print("Classification Model Loaded.")
    except Exception as e:
        # Best-effort startup: report the failure and keep the model unset.
        class_model = None
        print(f"Error loading models: {e}")

    try:
        print("Loading Detection Model...")
        detect_model = _load_model(
            get_detection_model(16),
            'models/best_faster_rcnn_detection.pth',
        )
        print("Detection Model Loaded.")
    except Exception as e:
        # Best-effort startup: report the failure and keep the model unset.
        detect_model = None
        print(f"Error loading models: {e}")
@app.get("/")
def read_root():
    """Return a fixed status payload showing that the API is reachable."""
    status_payload = {"status": "MediScanAI API is running"}
    return status_payload
# Model class names that the frontend knows under a different label.
_FRONTEND_NAMES = {
    "Pleural effusion": "Effusion",
    "Nodule/Mass": "Mass",
    "Pleural thickening": "Pleural_Thickening",
}

# Predictions scoring at or below this value are dropped. It is deliberately
# low: the frontend applies its own, stricter threshold.
_MIN_SCORE = 0.1

# Side length (pixels) of the square image fed to the detection model.
_DETECT_INPUT_SIZE = 256


def _frontend_name(name):
    """Translate a model class name into the label the frontend expects."""
    return _FRONTEND_NAMES.get(name, name)


def _sorted_by_score(results):
    """Return *results* ordered from highest to lowest ``score``."""
    return sorted(results, key=lambda item: item['score'], reverse=True)


def _classify(image_np):
    """Run the classification model on an RGB image array.

    Returns a list of ``{"name", "score"}`` dicts for every class whose
    sigmoid probability exceeds ``_MIN_SCORE``, highest score first.
    """
    # The classifier pipeline is fed a BGR (OpenCV-order) image.
    # NOTE(review): confirm this channel order matches what the model was
    # trained with; the normalization constants are applied per channel.
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    transformed = get_inference_transforms()(image=image_bgr)
    image_tensor = transformed['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = class_model(image_tensor)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    results = [
        {"name": _frontend_name(name), "score": float(prob)}
        for name, prob in zip(CLASS_NAMES_CLASSIFY, probs)
        if prob > _MIN_SCORE
    ]
    return _sorted_by_score(results)


def _detect(image_np):
    """Run the detection model on an RGB image array.

    Returns a list of ``{"name", "score", "box"}`` dicts for every detection
    scoring above ``_MIN_SCORE``, highest score first. Box coordinates are
    percentages (0-100) of the image width/height for the frontend SVG
    viewBox.
    """
    image_resized = cv2.resize(image_np, (_DETECT_INPUT_SIZE, _DETECT_INPUT_SIZE))
    image_tensor = torch.as_tensor(image_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0
    image_tensor = image_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = detect_model(image_tensor)[0]

    boxes = prediction['boxes'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()

    # Boxes are treated as pixel coordinates of the resized square input, so
    # one scale factor converts both axes to percentages.
    boxes_pct = boxes / _DETECT_INPUT_SIZE * 100

    results = []
    for box, label, score in zip(boxes_pct.tolist(), labels, scores):
        if score > _MIN_SCORE:
            x_min, y_min, x_max, y_max = box
            results.append({
                "name": _frontend_name(CLASS_NAMES_DETECT[int(label)]),
                "score": float(score),
                "box": {
                    "xMin": x_min, "yMin": y_min, "xMax": x_max, "yMax": y_max
                }
            })
    # No extra de-duplication is done here; overlapping boxes are returned
    # exactly as the model produced them, only re-ordered by score.
    return _sorted_by_score(results)


@app.post("/predict")
async def predict(file: UploadFile = File(...), mode: str = Form("classify")):
    """Run classification or detection on an uploaded image.

    Args:
        file: The uploaded image; anything Pillow can open is accepted.
        mode: ``"classify"`` (default) or ``"detect"``.

    Returns:
        ``{"mode": <mode>, "predictions": [...]}`` on success, where each
        prediction has ``name`` and ``score`` and, in detect mode, a ``box``
        with ``xMin``/``yMin``/``xMax``/``yMax`` percentages.
        ``{"error": <message>}`` when the mode is unknown, the requested
        model is not loaded, or the upload cannot be decoded as an image.
    """
    handlers = {
        "classify": (class_model, _classify),
        "detect": (detect_model, _detect),
    }
    # Validate the cheap inputs first so bad requests never decode the image.
    if mode not in handlers:
        return {"error": "Invalid mode"}

    model, run_inference = handlers[mode]
    if model is None:
        # Startup loading failed (or has not run); avoid calling None.
        return {"error": f"Model for mode '{mode}' is not loaded"}

    # Read image
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert('RGB')
    except (OSError, ValueError):
        # Pillow raises OSError subclasses for unidentified/truncated files.
        return {"error": "Invalid image file"}
    image_np = np.array(image)

    return {"mode": mode, "predictions": run_inference(image_np)}