Spaces:
Sleeping
Sleeping
import io
import logging

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import models, transforms
app = FastAPI()

# Output order of the classifier head: the i-th logit produced by the model
# is reported under LABELS[i] (14 CheXpert observations).
LABELS = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
class DenseNet121_CheXpert(torch.nn.Module):
    """DenseNet-121 backbone whose classifier is a ``num_labels``-way linear head.

    ``forward`` returns raw logits (no activation is applied); convert them to
    per-label probabilities with a sigmoid.

    Args:
        num_labels: Number of output logits (defaults to 14).
    """

    def __init__(self, num_labels=14):
        super().__init__()
        # weights=None -> randomly initialised backbone; the fine-tuned
        # checkpoint is presumably loaded into it afterwards. ``pretrained=``
        # has been deprecated in favour of ``weights=`` since torchvision 0.13.
        # The attribute name ``densenet`` is part of the state-dict key names,
        # so it must not be renamed.
        self.densenet = models.densenet121(weights=None)
        self.densenet.classifier = torch.nn.Linear(
            self.densenet.classifier.in_features, num_labels
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_labels)`` for an image batch ``x``."""
        return self.densenet(x)
# Download the fine-tuned checkpoint from the Hugging Face Hub and build the
# model once, at import time.
model_path = hf_hub_download(
    repo_id="itsomk/chexpert-densenet121",
    filename="pytorch_model.safetensors",
)
model = DenseNet121_CheXpert()
state = load_file(model_path)

# strict=False tolerates key mismatches, but any mismatch means the affected
# layers keep their random initialisation and predictions are meaningless.
# Surface that at startup instead of failing silently.
_load_result = model.load_state_dict(state, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Checkpoint %s does not match the model exactly: "
        "%d missing keys (first: %s), %d unexpected keys (first: %s)",
        model_path,
        len(_load_result.missing_keys),
        _load_result.missing_keys[:5],
        len(_load_result.unexpected_keys),
        _load_result.unexpected_keys[:5],
    )

# Inference mode: batch-norm layers use their running statistics.
model.eval()
# Transform applied to every decoded RGB image before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std values.
#      Presumably these match the checkpoint's training normalisation -- TODO confirm.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            std=[0.229, 0.224, 0.225],
            mean=[0.485, 0.456, 0.406],
        ),
    ]
)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image against every label in ``LABELS``.

    The upload is decoded with PIL and converted to 3-channel RGB. It is then
    run through ``preprocess`` and ``model``. Each logit is mapped to an
    independent probability with a sigmoid (multi-label, not softmax).

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        Mapping from each label in ``LABELS`` to its probability in [0, 1].

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL.UnidentifiedImageError and truncated files;
        # report a client error instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    x = preprocess(img).unsqueeze(0)  # add a leading batch dimension of 1
    # NOTE(review): the forward pass is CPU-bound and runs on the event loop;
    # consider offloading it to a thread pool under concurrent load.
    with torch.no_grad():
        probs = torch.sigmoid(model(x)).squeeze(0).tolist()
    return dict(zip(LABELS, probs))