saba2000 commited on
Commit
c2882e7
·
verified ·
1 Parent(s): 8ad1547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -13,13 +13,20 @@ target_labels = ['Pneumonia', 'Consolidation', 'Edema']
13
  target_idxs = [labels.index(lbl) for lbl in target_labels]
14
 
15
  def predict(image):
16
- image = image.convert("RGB").resize((224, 224))
 
 
 
 
 
17
  inputs = processor(images=image, return_tensors="pt")
 
18
  with torch.no_grad():
19
- logits = model(**inputs).logits
20
- probs = torch.sigmoid(logits).squeeze()
 
 
21
 
22
- ```
23
  detected = []
24
  results = []
25
  for idx, lbl in zip(target_idxs, target_labels):
 
13
  target_idxs = [labels.index(lbl) for lbl in target_labels]
14
 
15
  def predict(image):
16
+ # Make sure image is RGB
17
+ if image.mode != "RGB":
18
+ image = image.convert("RGB")
19
+
20
+ ```
21
+ # Process the image properly
22
  inputs = processor(images=image, return_tensors="pt")
23
+
24
  with torch.no_grad():
25
+ logits = model(**inputs).logits
26
+
27
+ # Keep batch dimension for safety
28
+ probs = torch.sigmoid(logits)[0] # [batch, num_labels] -> [num_labels]
29
 
 
30
  detected = []
31
  results = []
32
  for idx, lbl in zip(target_idxs, target_labels):