sagarmee commited on
Commit
baa4bc1
·
verified ·
1 Parent(s): d7129cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -16
app.py CHANGED
@@ -2,19 +2,28 @@ from fastapi import FastAPI, UploadFile, File
2
  from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
- from transformers import AutoModelForImageClassification
 
6
  import io
7
 
8
  app = FastAPI()
9
 
10
- MODEL_NAME = "itsomk/chexpert-densenet121"
 
11
 
12
- # Load model
13
- model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
14
  model.eval()
15
- model.to("cpu")
16
 
17
- # Image preprocessing (manual, because no image processor exists)
18
  transform = transforms.Compose([
19
  transforms.Resize((224, 224)),
20
  transforms.ToTensor(),
@@ -24,20 +33,29 @@ transform = transforms.Compose([
24
  )
25
  ])
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  @app.post("/predict")
28
  async def predict(file: UploadFile = File(...)):
29
- image_bytes = await file.read()
30
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
31
-
32
- image_tensor = transform(image).unsqueeze(0)
33
 
34
  with torch.no_grad():
35
- outputs = model(image_tensor)
36
- probs = torch.sigmoid(outputs.logits)[0]
37
 
38
- predictions = {
39
- model.config.id2label[i]: float(probs[i])
40
- for i in range(len(probs))
41
  }
42
 
43
- return {"predictions": predictions}
 
2
  from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
+ import torchvision.models as models
6
+ from huggingface_hub import hf_hub_download
7
  import io
8
 
9
  app = FastAPI()
10
 
11
+ MODEL_REPO = "itsomk/chexpert-densenet121"
12
+ MODEL_FILE = "pytorch_model.bin"
13
 
14
+ # Download weights
15
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
16
+
17
+ # Load DenseNet121
18
+ model = models.densenet121(pretrained=False)
19
+ num_classes = 14
20
+ model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
21
+
22
+ state_dict = torch.load(model_path, map_location="cpu")
23
+ model.load_state_dict(state_dict)
24
  model.eval()
 
25
 
26
+ # Image preprocessing (CheXpert standard)
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
 
33
  )
34
  ])
35
 
36
+ LABELS = [
37
+ "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
38
+ "Enlarged Cardiomediastinum", "Fracture", "Lung Lesion",
39
+ "Lung Opacity", "No Finding", "Pleural Effusion",
40
+ "Pleural Other", "Pneumonia", "Pneumothorax", "Support Devices"
41
+ ]
42
+
43
+ @app.get("/")
44
+ def health():
45
+ return {"status": "CheXpert DenseNet121 is running"}
46
+
47
  @app.post("/predict")
48
  async def predict(file: UploadFile = File(...)):
49
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
50
+ image = transform(image).unsqueeze(0)
 
 
51
 
52
  with torch.no_grad():
53
+ logits = model(image)
54
+ probs = torch.sigmoid(logits)[0]
55
 
56
+ results = {
57
+ LABELS[i]: float(probs[i])
58
+ for i in range(len(LABELS))
59
  }
60
 
61
+ return results