from fastapi import FastAPI, UploadFile, File
import torch
from torchvision import models, transforms
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import io
app = FastAPI()
# The 14 CheXpert observation labels. Index i pairs with output logit i
# via dict(zip(LABELS, probs)) in the /predict handler — the order is
# assumed to match the checkpoint's training label order (TODO: confirm
# against the itsomk/chexpert-densenet121 model card).
LABELS = [
"No Finding","Enlarged Cardiomediastinum","Cardiomegaly","Lung Opacity",
"Lung Lesion","Edema","Consolidation","Pneumonia","Atelectasis",
"Pneumothorax","Pleural Effusion","Pleural Other","Fracture","Support Devices"
]
class DenseNet121_CheXpert(torch.nn.Module):
    """DenseNet-121 backbone with a 14-way head for CheXpert observations.

    The stock classifier layer is swapped for a ``Linear`` emitting one raw
    logit per label (multi-label setup: the caller applies ``sigmoid``, not
    softmax — see the /predict handler).

    Args:
        num_labels: number of output logits (defaults to the 14 CheXpert labels).
    """

    def __init__(self, num_labels: int = 14):
        super().__init__()
        # weights=None replaces the pretrained=False argument deprecated in
        # torchvision 0.13+; the real weights are loaded separately from the
        # Hugging Face Hub checkpoint, so no pretrained weights are wanted here.
        self.densenet = models.densenet121(weights=None)
        self.densenet.classifier = torch.nn.Linear(
            self.densenet.classifier.in_features, num_labels
        )

    def forward(self, x):
        # x: image batch tensor; returns logits of shape (batch, num_labels).
        return self.densenet(x)
# Download the checkpoint from the Hugging Face Hub at import time
# (hf_hub_download caches locally, so only the first start pays the cost).
model_path = hf_hub_download(
    repo_id="itsomk/chexpert-densenet121",
    filename="pytorch_model.safetensors"
)
model = DenseNet121_CheXpert()
state = load_file(model_path)
# NOTE(review): strict=False silently drops missing/unexpected keys, so a
# partially matching checkpoint loads without any error — consider inspecting
# the returned incompatible-keys result, or strict=True if the keys align.
model.load_state_dict(state, strict=False)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
# Preprocessing applied to every uploaded image before inference:
# resize to the network's expected 224x224 input, convert to a float
# tensor, then normalize with the standard ImageNet statistics.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image.

    Returns:
        A mapping of each CheXpert label to its sigmoid probability
        (independent per-label probabilities — multi-label output).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (previously this surfaced as an unhandled 500).
    """
    # Local import so the file-level import line is untouched; fastapi is
    # already a dependency of this module.
    from fastapi import HTTPException

    raw = await file.read()
    try:
        # PIL raises UnidentifiedImageError (an OSError subclass) on bytes
        # that are not a decodable image.
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except OSError as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    x = preprocess(img).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        probs = torch.sigmoid(model(x)).squeeze().tolist()
    return dict(zip(LABELS, probs))