# chexpert-api / app.py
# Last commit: a05959f "Update app.py" by sagarmee (verified)
import io
import warnings

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import models, transforms
# The FastAPI application instance; endpoints are attached to it with route
# decorators further down the file.
app = FastAPI()

# Output label names. predict() pairs these positionally with the model's
# outputs, so this order must match the output order the checkpoint was
# trained with (presumably the standard CheXpert label order — confirm).
LABELS = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
class DenseNet121_CheXpert(torch.nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by a multi-label head.

    The head is a single ``Linear`` layer that emits one raw logit per label.
    Callers apply ``sigmoid`` themselves to get independent per-label
    probabilities.

    The torchvision network is stored as ``self.densenet``, so every key in
    this module's state dict is prefixed with ``densenet.``. A checkpoint must
    use the same prefix for its tensors to load.
    """

    def __init__(self, num_labels=14):
        """Build the architecture with ``num_labels`` output logits.

        Args:
            num_labels: Number of output logits (one per label). Defaults to 14.
        """
        super().__init__()
        # Architecture only, no pretrained weights: the real weights are loaded
        # from a checkpoint afterwards. ``weights=None`` replaces the deprecated
        # ``pretrained=False`` keyword (torchvision >= 0.13).
        self.densenet = models.densenet121(weights=None)
        # Swap the default ImageNet head for one sized to our label set, keeping
        # the backbone's feature width as the input size.
        self.densenet.classifier = torch.nn.Linear(
            self.densenet.classifier.in_features, num_labels
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits, one per label, for each image in ``x``.

        ``x`` is assumed to be a batched NCHW float tensor, such as the output
        of ``preprocess(img).unsqueeze(0)``. TODO confirm for other callers.
        """
        return self.densenet(x)
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and build the model
# once at import time; every request to /predict reuses this single instance.
model_path = hf_hub_download(
    repo_id="itsomk/chexpert-densenet121",
    filename="pytorch_model.safetensors",
)
model = DenseNet121_CheXpert()
state = load_file(model_path)
# strict=False lets loading proceed when checkpoint and module keys differ.
# However, any parameter it cannot match keeps its untrained initial value, and
# the endpoint would then return meaningless scores without raising an error.
# Report mismatches so a broken or mis-prefixed checkpoint shows up in the logs.
load_report = model.load_state_dict(state, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    warnings.warn(
        f"Checkpoint {model_path!r} does not match DenseNet121_CheXpert exactly: "
        f"{len(load_report.missing_keys)} missing key(s) left at their initial "
        f"values (first few: {load_report.missing_keys[:5]}), "
        f"{len(load_report.unexpected_keys)} unexpected key(s) ignored "
        f"(first few: {load_report.unexpected_keys[:5]}).",
        RuntimeWarning,
        stacklevel=1,
    )
# Switch layers such as BatchNorm to inference behaviour before serving.
model.eval()
# Per-channel RGB normalisation statistics. These are the standard ImageNet
# values; presumably they match the checkpoint's training pipeline — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Transform applied to every uploaded image before inference: resize to exactly
# 224x224 (aspect ratio is not preserved), convert the PIL image to a tensor,
# then normalise each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def _predict_from_bytes(data):
    """Decode an uploaded image and return ``{label: probability}``.

    This runs synchronously (PIL decoding plus a model forward pass), so code
    running on the event loop must offload it to a worker thread.

    Args:
        data: Raw bytes of the uploaded file.

    Returns:
        dict mapping each name in ``LABELS`` to the sigmoid of the
        corresponding model output.

    Raises:
        HTTPException: 400 if ``data`` cannot be decoded as an image.
    """
    try:
        # PIL decodes lazily, so the real decode happens inside convert();
        # both calls must be inside the try. UnidentifiedImageError (unknown
        # format, empty upload) and truncated-file errors are OSError subclasses.
        with Image.open(io.BytesIO(data)) as src:
            img = src.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    batch = preprocess(img).unsqueeze(0)  # add the leading batch dimension
    with torch.no_grad():
        logits = model(batch)
    # Take row 0 of the batch instead of calling squeeze(). squeeze() would also
    # drop the label axis if the model had a single output, making tolist()
    # return a bare float that zip() cannot iterate.
    probs = torch.sigmoid(logits)[0].tolist()
    return dict(zip(LABELS, probs))


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image against every label in ``LABELS``.

    Returns a JSON object mapping each label to an independent probability in
    [0, 1] (the sigmoid of the model logit). Responds with HTTP 400 if the
    upload is not a decodable image.
    """
    data = await file.read()
    # Decoding and inference are blocking CPU work. Running them directly in
    # this coroutine would stall the event loop and every other in-flight request.
    return await run_in_threadpool(_predict_from_bytes, data)