sagarmee commited on
Commit
a05959f
·
verified ·
1 Parent(s): b605c02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -1,19 +1,55 @@
1
- from huggingface_hub import hf_hub_download
 
 
2
  from safetensors.torch import load_file
 
 
 
3
 
4
- MODEL_REPO = "itsomk/chexpert-densenet121"
5
- MODEL_FILE = "chexpert_pytorch.safetensors"
6
 
7
- # Download weights file
8
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
 
 
 
9
 
10
- # Load the safetensors weights
11
- state = load_file(model_path)
 
 
 
 
 
 
 
12
 
13
- # Build the DenseNet121 model and load weights
14
- import torchvision.models as models
15
- model = models.densenet121(pretrained=False)
16
- num_features = model.classifier.in_features
17
- model.classifier = torch.nn.Linear(num_features, 14)
 
 
 
18
  model.load_state_dict(state, strict=False)
19
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ import torch
3
+ from torchvision import models, transforms
4
  from safetensors.torch import load_file
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image
7
+ import io
8
 
9
+ app = FastAPI()
 
10
 
11
+ LABELS = [
12
+ "No Finding","Enlarged Cardiomediastinum","Cardiomegaly","Lung Opacity",
13
+ "Lung Lesion","Edema","Consolidation","Pneumonia","Atelectasis",
14
+ "Pneumothorax","Pleural Effusion","Pleural Other","Fracture","Support Devices"
15
+ ]
16
 
17
+ class DenseNet121_CheXpert(torch.nn.Module):
18
+ def __init__(self, num_labels=14):
19
+ super().__init__()
20
+ self.densenet = models.densenet121(pretrained=False)
21
+ self.densenet.classifier = torch.nn.Linear(
22
+ self.densenet.classifier.in_features, num_labels
23
+ )
24
+ def forward(self, x):
25
+ return self.densenet(x)
26
 
27
+ # download weights
28
+ model_path = hf_hub_download(
29
+ repo_id="itsomk/chexpert-densenet121",
30
+ filename="pytorch_model.safetensors"
31
+ )
32
+
33
+ model = DenseNet121_CheXpert()
34
+ state = load_file(model_path)
35
  model.load_state_dict(state, strict=False)
36
  model.eval()
37
+
38
+ preprocess = transforms.Compose([
39
+ transforms.Resize((224,224)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(
42
+ mean=[0.485,0.456,0.406],
43
+ std=[0.229,0.224,0.225]
44
+ )
45
+ ])
46
+
47
+ @app.post("/predict")
48
+ async def predict(file: UploadFile = File(...)):
49
+ img = Image.open(io.BytesIO(await file.read())).convert("RGB")
50
+ x = preprocess(img).unsqueeze(0)
51
+
52
+ with torch.no_grad():
53
+ probs = torch.sigmoid(model(x)).squeeze().tolist()
54
+
55
+ return dict(zip(LABELS, probs))