"""FastAPI service that scores chest X-ray images with a CheXpert DenseNet-121.

On import, the model weights are downloaded from the Hugging Face Hub and
loaded into a DenseNet-121 whose classifier head has one output per entry in
``LABELS``. ``POST /predict`` accepts an uploaded image and returns a mapping
from label name to an independent sigmoid probability.
"""

import io
import logging

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import models, transforms

logger = logging.getLogger(__name__)

app = FastAPI()

# Output order of the classifier head: logit column i corresponds to LABELS[i].
LABELS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture",
    "Support Devices",
]


class DenseNet121_CheXpert(torch.nn.Module):
    """DenseNet-121 backbone with a ``num_labels``-way linear classifier head.

    The backbone is built without pretrained weights; every parameter is
    expected to be filled in from the downloaded checkpoint.
    """

    def __init__(self, num_labels=14):
        super().__init__()
        # `weights=None` replaces the deprecated `pretrained=False`
        # (torchvision >= 0.13).
        self.densenet = models.densenet121(weights=None)
        # Swap the default classifier for one sized to the label set.
        self.densenet.classifier = torch.nn.Linear(
            self.densenet.classifier.in_features, num_labels
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits, one column per label."""
        return self.densenet(x)


# Download the weights from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="itsomk/chexpert-densenet121",
    filename="pytorch_model.safetensors",
)

model = DenseNet121_CheXpert(num_labels=len(LABELS))
state = load_file(model_path)
# strict=False tolerates harmless extras, but the result must still be
# inspected: a key-name mismatch would otherwise leave layers randomly
# initialised and the service would silently return meaningless scores.
_load_result = model.load_state_dict(state, strict=False)
_missing_keys = [
    key
    for key in _load_result.missing_keys
    # BatchNorm step counters do not influence eval-mode inference.
    if not key.endswith("num_batches_tracked")
]
if _missing_keys:
    raise RuntimeError(
        f"Checkpoint {model_path} is missing {len(_missing_keys)} model "
        f"parameter(s), e.g. {_missing_keys[:5]}; check that its key names "
        "match DenseNet121_CheXpert (prefix 'densenet.')."
    )
if _load_result.unexpected_keys:
    logger.warning(
        "Ignoring %d unexpected checkpoint key(s), e.g. %s",
        len(_load_result.unexpected_keys),
        _load_result.unexpected_keys[:5],
    )
model.eval()

# Resize to 224x224 and normalise with the standard ImageNet channel
# statistics (presumably matching the checkpoint's training preprocessing —
# TODO confirm against the model card).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def _predict_bytes(data: bytes) -> dict:
    """Decode image bytes, run the model, and return per-label probabilities.

    Args:
        data: Raw bytes of the uploaded image file.

    Returns:
        Mapping from each name in ``LABELS`` to its sigmoid probability.

    Raises:
        HTTPException: 400 if ``data`` cannot be decoded as an image.
    """
    try:
        # Image.open is lazy; convert() forces the full decode, so both must
        # be inside the try to catch truncated or corrupt uploads as well.
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    x = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        # Independent sigmoid per label (multi-label output); take the
        # single batch row explicitly rather than squeezing all size-1 dims.
        probs = torch.sigmoid(model(x))[0].tolist()
    return dict(zip(LABELS, probs))


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image and return ``{label: probability}``.

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    """
    data = await file.read()
    # Decoding and inference are CPU-bound; run them in a worker thread so
    # the event loop keeps serving other requests meanwhile.
    return await run_in_threadpool(_predict_bytes, data)