ALYYAN commited on
Commit
3726151
·
unverified ·
1 Parent(s): 4f6f62e

Update prediction.py

Browse files
Files changed (1) hide show
  1. app/prediction.py +45 -20
app/prediction.py CHANGED
@@ -1,7 +1,7 @@
1
  # app/prediction.py
2
 
3
  import torch
4
- from transformers import ViTImageProcessor, ViTForImageClassification
5
  from PIL import Image
6
  from pathlib import Path
7
  import numpy as np
@@ -13,10 +13,41 @@ ImageType = Union[str, Path, bytes, np.ndarray]
13
  class PredictionPipeline:
14
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
- self.processor = ViTImageProcessor.from_pretrained(model_path)
17
- self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
18
- self.model.eval()
19
- self.id2label = self.model.config.id2label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
22
  if not image_sources:
@@ -33,15 +64,18 @@ class PredictionPipeline:
33
  else:
34
  image = Image.open(source).convert("RGB")
35
 
 
 
 
 
36
  valid_images_as_np.append(np.array(image))
37
 
38
- inputs = self.processor(images=image, return_tensors="pt").to(self.device)
39
  with torch.no_grad():
40
- outputs = self.model(**inputs)
41
  logits = outputs.logits
42
  all_logits.append(logits)
43
 
44
- # --- NEW: Calculate individual prediction ---
45
  ind_probs = torch.nn.functional.softmax(logits, dim=-1)
46
  ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
47
  individual_results.append({
@@ -50,13 +84,14 @@ class PredictionPipeline:
50
  })
51
 
52
  except Exception as e:
53
- print(f"Skipping a corrupted or invalid image file. Error: {e}")
54
  individual_results.append({"prediction": "Error", "confidence": 0})
55
  continue
56
 
57
  if not all_logits:
58
- return {"error": "All images were invalid."}
59
 
 
60
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
61
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
62
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
@@ -64,15 +99,6 @@ class PredictionPipeline:
64
  final_prediction = self.id2label[predicted_class_idx.item()]
65
  final_confidence = confidence_score.item()
66
 
67
- # --- NEW: Add confidence check ---
68
- if final_confidence < 0.60:
69
- return {
70
- "error": "Low Confidence Prediction",
71
- "details": f"The model's confidence of {final_confidence:.1%} is too low. "
72
- "Please ensure the uploaded image is a clear, frontal chest X-ray."
73
- }
74
-
75
- # --- Watermarking (same as before) ---
76
  watermarked_images = [
77
  add_watermark(img_np, res["prediction"], res["confidence"])
78
  for img_np, res in zip(valid_images_as_np, individual_results)
@@ -85,4 +111,3 @@ class PredictionPipeline:
85
  "individual_results": individual_results,
86
  "watermarked_images": watermarked_images
87
  }
88
-
 
1
  # app/prediction.py
2
 
3
  import torch
4
+ from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, ResNetForImageClassification
5
  from PIL import Image
6
  from pathlib import Path
7
  import numpy as np
 
13
  class PredictionPipeline:
14
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # --- Pneumonia Model (our fine-tuned model) ---
18
+ self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
19
+ self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
20
+ self.pneumonia_model.eval()
21
+ self.id2label = self.pneumonia_model.config.id2label
22
+
23
+ # --- Sanity Check Model (general purpose) ---
24
+ # This model knows what many things are, including X-rays.
25
+ self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
26
+ self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
27
+ self.sanity_model.eval()
28
+
29
+ def is_likely_xray(self, image: Image.Image) -> bool:
30
+ """
31
+ Uses the general-purpose ResNet-50 model to check if the image
32
+ is likely a chest X-ray.
33
+ """
34
+ with torch.no_grad():
35
+ inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
36
+ outputs = self.sanity_model(**inputs)
37
+ logits = outputs.logits
38
+
39
+ # Get the top 5 predicted classes
40
+ top5_probs, top5_indices = torch.topk(logits.softmax(-1), 5)
41
+
42
+ # The model's labels are in its config. We look for 'x-ray' or 'chest'.
43
+ for idx in top5_indices[0]:
44
+ label = self.sanity_model.config.id2label[idx.item()].lower()
45
+ if "x-ray" in label or "chest" in label or "radiograph" in label:
46
+ print(f"Sanity check passed: Image classified as '{label}'")
47
+ return True
48
+
49
+ print("Sanity check failed: Image is not classified as an X-ray.")
50
+ return False
51
 
52
  def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
53
  if not image_sources:
 
64
  else:
65
  image = Image.open(source).convert("RGB")
66
 
67
+ # --- NEW: Perform the sanity check first! ---
68
+ if not self.is_likely_xray(image):
69
+ raise ValueError("Image does not appear to be a chest X-ray.")
70
+
71
  valid_images_as_np.append(np.array(image))
72
 
73
+ inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
74
  with torch.no_grad():
75
+ outputs = self.pneumonia_model(**inputs)
76
  logits = outputs.logits
77
  all_logits.append(logits)
78
 
 
79
  ind_probs = torch.nn.functional.softmax(logits, dim=-1)
80
  ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
81
  individual_results.append({
 
84
  })
85
 
86
  except Exception as e:
87
+ print(f"Skipping an invalid image file. Error: {e}")
88
  individual_results.append({"prediction": "Error", "confidence": 0})
89
  continue
90
 
91
  if not all_logits:
92
+ return {"error": "Invalid Image", "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image."}
93
 
94
+ # ... (Aggregate prediction and watermarking are the same) ...
95
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
96
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
97
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
 
99
  final_prediction = self.id2label[predicted_class_idx.item()]
100
  final_confidence = confidence_score.item()
101
 
 
 
 
 
 
 
 
 
 
102
  watermarked_images = [
103
  add_watermark(img_np, res["prediction"], res["confidence"])
104
  for img_np, res in zip(valid_images_as_np, individual_results)
 
111
  "individual_results": individual_results,
112
  "watermarked_images": watermarked_images
113
  }