cliniscan-api / app.py
luckysoni10's picture
Fix: load efficientnet without pretrained hash
4a13811
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import io
import cv2
import base64
from torchvision import transforms
from torchvision.models import efficientnet_b0
from huggingface_hub import hf_hub_download
# ASGI application object that all route decorators in this file attach to.
app = FastAPI(title="CliniScan API", version="1.0")

# CORS is left fully open: any origin, HTTP method and header is accepted.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ── Constants ──────────────────────────────────────────────
# Label for each classifier output index (position in the list == class id).
CLASS_NAMES = [
    'Aortic enlargement',
    'Atelectasis',
    'Calcification',
    'Cardiomegaly',
    'Consolidation',
    'ILD',
    'Infiltration',
    'Lung Opacity',
    'Nodule/Mass',
    'Other lesion',
    'Pleural effusion',
    'Pleural thickening',
    'Pneumothorax',
    'Pulmonary fibrosis',
]
# Derived from the label list so the two can never drift apart (currently 14).
NUM_CLASSES = len(CLASS_NAMES)
# All tensors and weights are placed on the CPU.
DEVICE = torch.device('cpu')
# Hugging Face Hub repository that hosts the model weight files.
HF_REPO = "luckysoni10/cliniscan-weights"
# ── Load Models ────────────────────────────────────────────
print("Loading models...")

# Classification model: an EfficientNet-B0 skeleton (no pretrained weights)
# whose final linear layer is replaced with a NUM_CLASSES-way head before the
# fine-tuned state dict is loaded into it.
clf_model = efficientnet_b0(weights=None)
clf_model.classifier[1] = nn.Linear(
    clf_model.classifier[1].in_features, NUM_CLASSES)
clf_path = hf_hub_download(repo_id=HF_REPO, filename="m3_efficientnet_adamw.pth")
# The checkpoint is downloaded from a remote repository, so restrict
# unpickling to tensors/containers instead of arbitrary Python objects.
# The result is passed straight to load_state_dict, so a plain state dict
# is all that is needed here.
state = torch.load(clf_path, map_location=DEVICE, weights_only=True)
clf_model.load_state_dict(state)
# Inference only: switch dropout / batch-norm layers to eval behaviour.
clf_model.eval()

# Detection model
# NOTE(review): ultralytics presumably unpickles best.pt itself — only point
# HF_REPO at a repository you trust.
from ultralytics import YOLO
det_path = hf_hub_download(repo_id=HF_REPO, filename="best.pt")
det_model = YOLO(det_path)
print("✅ Models loaded!")
# ── Transforms ─────────────────────────────────────────────
# Preprocessing applied to every PIL image before it reaches the classifier:
# resize to a fixed square, convert to a tensor, then normalise per channel.
_INPUT_SIZE = (224, 224)
# These values match the widely used ImageNet channel statistics.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# ── Grad-CAM ───────────────────────────────────────────────
class GradCAM:
    """Grad-CAM helper that captures one layer's activations and gradients.

    Hooks are registered on ``target_layer`` at construction time. Every
    forward pass through the layer overwrites ``activations`` and every
    backward pass overwrites ``gradients``.

    Args:
        model: Module whose output is indexed as ``out[0, class_idx]``.
        target_layer: Sub-module of ``model`` to hook.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        target_layer.register_forward_hook(self._store_activations)
        target_layer.register_full_backward_hook(self._store_gradients)

    def _store_activations(self, module, inputs, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.detach()

    def _store_gradients(self, module, grad_input, grad_output):
        """Backward hook: keep the detached gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate(self, tensor, class_idx):
        """Compute a normalised class-activation map for one input.

        Args:
            tensor: Model input; only batch element 0 is used.
            class_idx: Index of the output score to back-propagate from.

        Returns:
            A 2-D NumPy array with the spatial size of the hooked layer's
            output, scaled so its minimum is 0 and (unless the map is
            constant) its maximum is 1.
        """
        self.model.zero_grad()
        out = self.model(tensor)
        out[0, class_idx].backward()
        # One weight per channel: the spatial mean of that channel's gradient.
        weights = self.gradients[0].mean(dim=(1, 2))
        # Weighted sum over channels as a single tensor reduction.
        cam = (weights[:, None, None] * self.activations[0]).sum(dim=0)
        cam = torch.relu(cam)
        # Min-max normalise. The max must be taken AFTER removing the minimum,
        # otherwise the map cannot reach 1 whenever the minimum is above 0.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.numpy()
# Shared Grad-CAM instance hooked onto the last stage of the classifier's
# feature extractor; the hooks stay registered for the life of the process.
_gradcam_target_layer = clf_model.features[-1]
gradcam = GradCAM(clf_model, _gradcam_target_layer)
# ── Helper ─────────────────────────────────────────────────
def read_image(file_bytes):
    """Decode raw encoded image bytes into an RGB PIL image."""
    stream = io.BytesIO(file_bytes)
    decoded = Image.open(stream)
    return decoded.convert('RGB')
def img_to_base64(img_array):
    """Encode an image array as PNG and return it as a base64 string.

    Args:
        img_array: Image array accepted by ``cv2.imencode``.

    Returns:
        The PNG bytes, base64-encoded and decoded to a ``str``.

    Raises:
        ValueError: If OpenCV reports that PNG encoding did not succeed.
    """
    ok, buf = cv2.imencode('.png', img_array)
    # imencode signals failure through its first return value; do not go on
    # to base64-encode an invalid buffer.
    if not ok:
        raise ValueError("PNG encoding failed")
    return base64.b64encode(buf).decode('utf-8')
# ── Routes ─────────────────────────────────────────────────
@app.get("/")
def root():
    """Return a short liveness message for the API root path."""
    # The check-mark was previously stored as mis-decoded bytes ("βœ…").
    return {"status": "CliniScan API is running ✅"}
@app.get("/health")
def health():
    """Return a static health payload (no per-request checks are run)."""
    payload = dict(status="ok", models="loaded")
    return payload
@app.post("/predict/classify")
async def classify(file: UploadFile = File(...), threshold: float = 0.3):
    """Run the classifier on one uploaded image.

    Args:
        file: Uploaded image file.
        threshold: Probability at or above which a class counts as detected.

    Returns:
        JSON with every class probability (sorted high to low), the subset
        flagged as detected, and the threshold that was applied.
    """
    raw = await file.read()
    model_input = transform(read_image(raw)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = clf_model(model_input)
        # Independent per-class sigmoid: several classes may be detected.
        scores = torch.sigmoid(logits)[0].numpy()

    ranked = sorted(
        (
            {
                "class_id": idx,
                "class_name": label,
                "probability": round(float(score), 4),
                "detected": bool(score >= threshold),
            }
            for idx, (label, score) in enumerate(zip(CLASS_NAMES, scores))
        ),
        key=lambda entry: entry['probability'],
        reverse=True,
    )
    positives = [entry for entry in ranked if entry['detected']]

    payload = {
        "status": "success",
        "detected": positives,
        "all_probs": ranked,
        "threshold": threshold,
    }
    return JSONResponse(payload)
@app.post("/predict/detect")
async def detect(file: UploadFile = File(...), confidence: float = 0.25):
    """Run the detector on one uploaded image and draw the detections.

    Args:
        file: Uploaded image file.
        confidence: Confidence value forwarded to the detector as ``conf``.

    Returns:
        JSON with one entry per detected box, the box count, and the
        annotated image as a base64-encoded PNG. The image is resized to
        224x224 before detection, so box coordinates presumably refer to
        that resized image rather than the original upload.
    """
    # Drawing palette, built once per request instead of once per box.
    # The tuples are handed to OpenCV unchanged, on a BGR-ordered image.
    palette = (
        (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
        (255, 0, 255), (0, 255, 255), (128, 0, 0), (0, 128, 0),
        (0, 0, 128), (128, 128, 0), (128, 0, 128), (0, 128, 128),
        (64, 0, 0), (0, 64, 0),
    )

    img_bytes = await file.read()
    img = read_image(img_bytes)
    img_cv = cv2.cvtColor(np.array(img.resize((224, 224))), cv2.COLOR_RGB2BGR)
    results = det_model.predict(img_cv, conf=confidence, verbose=False)

    boxes = []
    for box in results[0].boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        # The colour lookup wraps around; guard the label lookup as well so a
        # class id outside CLASS_NAMES cannot raise IndexError.
        if 0 <= cls_id < len(CLASS_NAMES):
            label = CLASS_NAMES[cls_id]
        else:
            label = f"class_{cls_id}"
        color = palette[cls_id % len(palette)]
        cv2.rectangle(img_cv, (x1, y1), (x2, y2), color, 2)
        # Keep the caption at least 10 px from the top edge so it stays visible.
        cv2.putText(img_cv, f"{label} {conf:.2f}",
                    (x1, max(y1 - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1)
        boxes.append({
            "class_name": label,
            "confidence": round(conf, 3),
            "bbox": [x1, y1, x2, y2],
        })

    return JSONResponse({
        "status": "success",
        "boxes": boxes,
        "total_found": len(boxes),
        "annotated_image": img_to_base64(img_cv),
    })
@app.post("/predict/gradcam")
async def gradcam_endpoint(file: UploadFile = File(...)):
    """Return a Grad-CAM overlay for the top-scoring class of one image.

    Args:
        file: Uploaded image file.

    Returns:
        JSON with the predicted class name, its sigmoid score, the heatmap
        overlay and the resized original, both as base64-encoded PNGs.
    """
    size = (224, 224)
    pil_img = read_image(await file.read())
    base_bgr = cv2.cvtColor(np.array(pil_img.resize(size)), cv2.COLOR_RGB2BGR)

    model_input = transform(pil_img).unsqueeze(0).to(DEVICE)
    model_input.requires_grad = True

    # First pass (no gradients) only picks the class to explain.
    with torch.no_grad():
        scores = torch.sigmoid(clf_model(model_input))[0]
    best_idx = int(scores.argmax())

    # Second pass, inside generate(), runs with gradients enabled.
    cam_map = gradcam.generate(model_input, best_idx)
    cam_u8 = (cv2.resize(cam_map, size) * 255).astype(np.uint8)
    colored = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
    blended = cv2.addWeighted(base_bgr, 0.5, colored, 0.5, 0)

    payload = {
        "status": "success",
        "predicted_class": CLASS_NAMES[best_idx],
        "confidence": round(float(scores[best_idx]), 4),
        "heatmap_image": img_to_base64(blended),
        "original_image": img_to_base64(base_bgr),
    }
    return JSONResponse(payload)
@app.post("/predict/batch")
async def batch(files: list[UploadFile] = File(...), threshold: float = 0.3):
    """Classify several uploaded images, one forward pass per file.

    Args:
        files: Uploaded image files.
        threshold: Probability at or above which a class is reported.

    Returns:
        JSON with, for each file, its filename, the classes at or above the
        threshold (in class-index order) and how many there were.
    """
    summaries = []
    for upload in files:
        picture = read_image(await upload.read())
        model_input = transform(picture).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            scores = torch.sigmoid(clf_model(model_input))[0].numpy()

        findings = []
        for idx, score in enumerate(scores):
            if score >= threshold:
                findings.append({
                    "class_name": CLASS_NAMES[idx],
                    "probability": round(float(score), 4),
                })

        summaries.append({
            "filename": upload.filename,
            "detected": findings,
            "finding_count": len(findings),
        })

    return JSONResponse({"status": "success", "results": summaries})