sagarmee commited on
Commit
58d14cc
·
verified ·
1 Parent(s): 76165b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -57
app.py CHANGED
@@ -1,65 +1,19 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from PIL import Image
3
- import torch
4
- import torchvision.transforms as transforms
5
- import torchvision.models as models
6
  from huggingface_hub import hf_hub_download
7
- import io
8
-
9
- app = FastAPI()
10
 
11
  MODEL_REPO = "itsomk/chexpert-densenet121"
12
- MODEL_FILE = "chexpert_densenet121.pth"
13
 
 
14
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
15
 
16
- model = models.densenet121(pretrained=False)
17
- model.classifier = torch.nn.Linear(model.classifier.in_features, 14)
18
-
19
- checkpoint = torch.load(model_path, map_location="cpu")
20
-
21
- if "state_dict" in checkpoint:
22
- checkpoint = checkpoint["state_dict"]
23
-
24
- state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
25
- model.load_state_dict(state_dict, strict=False)
26
 
 
 
 
 
 
 
27
  model.eval()
28
-
29
-
30
- # Image preprocessing (CheXpert standard)
31
- transform = transforms.Compose([
32
- transforms.Resize((224, 224)),
33
- transforms.ToTensor(),
34
- transforms.Normalize(
35
- mean=[0.485, 0.456, 0.406],
36
- std=[0.229, 0.224, 0.225]
37
- )
38
- ])
39
-
40
- LABELS = [
41
- "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
42
- "Enlarged Cardiomediastinum", "Fracture", "Lung Lesion",
43
- "Lung Opacity", "No Finding", "Pleural Effusion",
44
- "Pleural Other", "Pneumonia", "Pneumothorax", "Support Devices"
45
- ]
46
-
47
- @app.get("/")
48
- def health():
49
- return {"status": "CheXpert DenseNet121 is running"}
50
-
51
- @app.post("/predict")
52
- async def predict(file: UploadFile = File(...)):
53
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
54
- image = transform(image).unsqueeze(0)
55
-
56
- with torch.no_grad():
57
- logits = model(image)
58
- probs = torch.sigmoid(logits)[0]
59
-
60
- results = {
61
- LABELS[i]: float(probs[i])
62
- for i in range(len(LABELS))
63
- }
64
-
65
- return results
 
 
 
 
 
 
1
  from huggingface_hub import hf_hub_download
2
+ from safetensors.torch import load_file
 
 
3
 
4
  MODEL_REPO = "itsomk/chexpert-densenet121"
5
+ MODEL_FILE = "chexpert_pytorch.safetensors"
6
 
7
+ # Download weights file
8
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
9
 
10
+ # Load the safetensors weights
11
+ state = load_file(model_path)
 
 
 
 
 
 
 
 
12
 
13
+ # Build the DenseNet121 model and load weights
14
+ import torchvision.models as models
15
+ model = models.densenet121(pretrained=False)
16
+ num_features = model.classifier.in_features
17
+ model.classifier = torch.nn.Linear(num_features, 14)
18
+ model.load_state_dict(state, strict=False)
19
  model.eval()