intellicure-api / api /pneumonia.py
arghyaxcodes's picture
feat(frontend): update next dependency version to 15.3.6 to fix CVE-2025-66478
2eee02a verified
# app/api/pneumonia.py
from fastapi import APIRouter, UploadFile, File, HTTPException
from core.utils import preprocess_image_pil, generate_gemini_insights
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
import os
from typing import Tuple
from PIL import Image
from huggingface_hub import hf_hub_download
router = APIRouter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ["NORMAL", "PNEUMONIA"]
transform = T.Compose(
[
T.Resize((256, 256)),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
model_path = hf_hub_download(
repo_id="arghyaxcodes/test-store",
filename="models/pneumonia.pt",
repo_type="dataset"
)
def load_model(model_path: str) -> nn.Module:
model = models.vgg19(pretrained=False)
in_features = (
model.classifier[0].in_features
if isinstance(model.classifier[0], nn.Linear)
else 25088
)
model.classifier = nn.Sequential(
nn.Linear(in_features, 4096), # type: ignore
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 2),
nn.LogSoftmax(dim=1),
)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device).eval()
return model
model = load_model(model_path)
def predict(image: Image.Image) -> Tuple[str, float]:
tensor_image: torch.Tensor = transform(image) # type: ignore
input_tensor = tensor_image.unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_tensor)
probs = torch.exp(output)
confidence, predicted_idx = torch.max(probs, dim=1)
return class_names[predicted_idx.item()], float(confidence.item()) # type: ignore
@router.post(
"/predict",
name="shwaas_veda",
description="Predict pneumonia from chest X-ray images",
)
async def classify_image(file: UploadFile = File(...)):
try:
contents = await file.read()
image = preprocess_image_pil(contents)
label, confidence = predict(image)
analysis = await generate_gemini_insights(
label, {label: confidence}, mode="pneumonia"
)
return {
"prediction": label,
"confidence": round(confidence, 4),
"gemini_analysis": analysis,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")