ALYYAN commited on
Commit
131eab2
·
unverified ·
1 Parent(s): cd5d737

Update prediction.py

Browse files
Files changed (1) hide show
  1. app/prediction.py +31 -29
app/prediction.py CHANGED
@@ -1,4 +1,4 @@
1
- # app/prediction.py
2
 
3
  import torch
4
  from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, ResNetForImageClassification
@@ -10,44 +10,50 @@ from .image_utils import add_watermark
10
 
11
  ImageType = Union[str, Path, bytes, np.ndarray]
12
 
 
 
 
 
 
 
 
 
 
13
  class PredictionPipeline:
14
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # --- Pneumonia Model (our fine-tuned model) ---
18
  self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
19
  self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
20
  self.pneumonia_model.eval()
21
  self.id2label = self.pneumonia_model.config.id2label
22
 
23
- # --- Sanity Check Model (general purpose) ---
24
- # This model knows what many things are, including X-rays.
25
  self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
26
  self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
27
  self.sanity_model.eval()
28
 
29
- def is_likely_xray(self, image: Image.Image) -> bool:
30
  """
31
- Uses the general-purpose ResNet-50 model to check if the image
32
- is likely a chest X-ray.
33
  """
34
  with torch.no_grad():
35
  inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
36
  outputs = self.sanity_model(**inputs)
37
  logits = outputs.logits
38
 
39
- # Get the top 5 predicted classes
40
- top5_probs, top5_indices = torch.topk(logits.softmax(-1), 5)
41
 
42
- # The model's labels are in its config. We look for 'x-ray' or 'chest'.
43
- for idx in top5_indices[0]:
44
  label = self.sanity_model.config.id2label[idx.item()].lower()
45
- if "x-ray" in label or "chest" in label or "radiograph" in label:
46
- print(f"Sanity check passed: Image classified as '{label}'")
47
- return True
 
 
48
 
49
- print("Sanity check failed: Image is not classified as an X-ray.")
50
- return False
51
 
52
  def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
53
  if not image_sources:
@@ -64,24 +70,20 @@ class PredictionPipeline:
64
  else:
65
  image = Image.open(source).convert("RGB")
66
 
67
- # --- NEW: Perform the sanity check first! ---
68
- if not self.is_likely_xray(image):
69
- raise ValueError("Image does not appear to be a chest X-ray.")
70
 
71
  valid_images_as_np.append(np.array(image))
72
 
 
73
  inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
74
  with torch.no_grad():
75
  outputs = self.pneumonia_model(**inputs)
76
  logits = outputs.logits
77
  all_logits.append(logits)
78
-
79
- ind_probs = torch.nn.functional.softmax(logits, dim=-1)
80
- ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
81
- individual_results.append({
82
- "prediction": self.id2label[ind_idx.item()],
83
- "confidence": ind_conf.item()
84
- })
85
 
86
  except Exception as e:
87
  print(f"Skipping an invalid image file. Error: {e}")
@@ -91,14 +93,14 @@ class PredictionPipeline:
91
  if not all_logits:
92
  return {"error": "Invalid Image", "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image."}
93
 
94
- # ... (Aggregate prediction and watermarking are the same) ...
95
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
96
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
97
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
98
-
99
  final_prediction = self.id2label[predicted_class_idx.item()]
100
  final_confidence = confidence_score.item()
101
-
 
102
  watermarked_images = [
103
  add_watermark(img_np, res["prediction"], res["confidence"])
104
  for img_np, res in zip(valid_images_as_np, individual_results)
 
1
+ # app/prediction.py (Final Version with Relaxed Sanity Check)
2
 
3
  import torch
4
  from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, ResNetForImageClassification
 
10
 
11
  ImageType = Union[str, Path, bytes, np.ndarray]
12
 
13
+ # A list of obviously non-medical terms to check against
14
+ FORBIDDEN_LABELS = [
15
+ "car", "truck", "van", "motorcycle", "bicycle", "bus", "train", "boat", "airplane",
16
+ "cat", "dog", "bird", "horse", "sheep", "cow", "bear", "zebra", "giraffe",
17
+ "landscape", "mountain", "beach", "forest", "building", "house", "road", "street",
18
+ "computer", "keyboard", "mouse", "laptop", "cellphone", "television",
19
+ "food", "plate", "bowl", "cup", "fork", "knife", "spoon"
20
+ ]
21
+
22
  class PredictionPipeline:
23
  def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
24
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
 
26
  self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
27
  self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
28
  self.pneumonia_model.eval()
29
  self.id2label = self.pneumonia_model.config.id2label
30
 
 
 
31
  self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
32
  self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
33
  self.sanity_model.eval()
34
 
35
+ def sanity_check(self, image: Image.Image) -> bool:
36
  """
37
+ Uses a general-purpose model to check if the image is something obviously
38
+ not a medical scan. Returns True if the image is plausible, False otherwise.
39
  """
40
  with torch.no_grad():
41
  inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
42
  outputs = self.sanity_model(**inputs)
43
  logits = outputs.logits
44
 
45
+ top5_indices = torch.topk(logits, 5).indices[0]
 
46
 
47
+ for idx in top5_indices:
 
48
  label = self.sanity_model.config.id2label[idx.item()].lower()
49
+ # Check for partial matches (e.g., 'sports car', 'fire truck')
50
+ for forbidden in FORBIDDEN_LABELS:
51
+ if forbidden in label:
52
+ print(f"Sanity check FAILED: Image classified as '{label}', which contains a forbidden term '{forbidden}'.")
53
+ return False # It's definitely not an X-ray
54
 
55
+ print("Sanity check PASSED: Image does not appear to be a common non-medical object.")
56
+ return True # It's plausible enough to proceed
57
 
58
  def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
59
  if not image_sources:
 
70
  else:
71
  image = Image.open(source).convert("RGB")
72
 
73
+ # --- NEW: Perform the relaxed sanity check ---
74
+ if not self.sanity_check(image):
75
+ raise ValueError("Image appears to be a common object, not a medical scan.")
76
 
77
  valid_images_as_np.append(np.array(image))
78
 
79
+ # ... (rest of the prediction logic is the same)
80
  inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
81
  with torch.no_grad():
82
  outputs = self.pneumonia_model(**inputs)
83
  logits = outputs.logits
84
  all_logits.append(logits)
85
+ ind_probs = torch.nn.functional.softmax(logits, dim=-1); ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
86
+ individual_results.append({"prediction": self.id2label[ind_idx.item()], "confidence": ind_conf.item()})
 
 
 
 
 
87
 
88
  except Exception as e:
89
  print(f"Skipping an invalid image file. Error: {e}")
 
93
  if not all_logits:
94
  return {"error": "Invalid Image", "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image."}
95
 
96
+ # ... (Aggregate prediction and watermarking are the same)
97
  avg_logits = torch.mean(torch.stack(all_logits), dim=0)
98
  probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
99
  confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
 
100
  final_prediction = self.id2label[predicted_class_idx.item()]
101
  final_confidence = confidence_score.item()
102
+ # NOTE: The low-confidence check has been removed as the sanity check is more robust.
103
+
104
  watermarked_images = [
105
  add_watermark(img_np, res["prediction"], res["confidence"])
106
  for img_np, res in zip(valid_images_as_np, individual_results)