# app/prediction.py (Final Version with Relaxed Sanity Check)
"""Pneumonia prediction pipeline with a relaxed non-medical-image sanity check."""

import io
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    ResNetForImageClassification,
    ViTForImageClassification,
    ViTImageProcessor,
)

from .image_utils import add_watermark

ImageType = Union[str, Path, bytes, np.ndarray]

# A list of obviously non-medical terms to check against
FORBIDDEN_LABELS = [
    "car", "truck", "van", "motorcycle", "bicycle", "bus", "train", "boat", "airplane",
    "cat", "dog", "bird", "horse", "sheep", "cow", "bear", "zebra", "giraffe",
    "landscape", "mountain", "beach", "forest", "building", "house", "road", "street",
    "computer", "keyboard", "mouse", "laptop", "cellphone", "television",
    "food", "plate", "bowl", "cup", "fork", "knife", "spoon",
]


class PredictionPipeline:
    """Classify images with a ViT model, guarded by a general-purpose sanity check.

    Each input image is first run through a ResNet classifier; images whose
    top predictions contain a term from ``FORBIDDEN_LABELS`` are rejected.
    Accepted images are classified by the ViT model loaded from ``model_path``
    and the per-image logits are averaged into one aggregate prediction.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load both models onto the GPU when available, otherwise the CPU.

        Args:
            model_path: Directory holding the fine-tuned ViT model and its
                image-processor configuration.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Fine-tuned classifier used for the actual prediction.
        self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
        self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.pneumonia_model.eval()
        self.id2label = self.pneumonia_model.config.id2label

        # General-purpose classifier used only to reject obvious non-medical images.
        self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
        self.sanity_model.eval()

    @staticmethod
    def _load_image(source: ImageType) -> Image.Image:
        """Decode one supported image source into an RGB PIL image.

        Arrays go through ``Image.fromarray``. Raw ``bytes`` are treated as
        encoded image data and wrapped in a ``BytesIO`` buffer, because
        ``Image.open`` would otherwise interpret a ``bytes`` object as a file
        name. Note that a ``bytes`` value intended as a file-system path is
        therefore no longer opened as a path. Anything else (path strings,
        ``Path`` objects, file-like objects) is handed to ``Image.open`` as-is.
        """
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            return Image.open(io.BytesIO(source)).convert("RGB")
        return Image.open(source).convert("RGB")

    def sanity_check(self, image: Image.Image) -> bool:
        """Reject images that a general-purpose model sees as an everyday object.

        The five highest-scoring labels of the sanity model are compared
        against ``FORBIDDEN_LABELS``.

        Args:
            image: The image to screen.

        Returns:
            ``False`` if any of the top-5 labels contains a forbidden term,
            ``True`` if the image is plausible enough to proceed.
        """
        with torch.no_grad():
            inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
            logits = self.sanity_model(**inputs).logits

        top5_indices = torch.topk(logits, 5).indices[0].tolist()
        sanity_labels = self.sanity_model.config.id2label

        for class_idx in top5_indices:
            label = sanity_labels[class_idx].lower()
            # Substring (not exact) matching is deliberate so that compound
            # labels such as 'sports car' or 'fire truck' are caught too.
            for forbidden in FORBIDDEN_LABELS:
                if forbidden in label:
                    print(f"Sanity check FAILED: Image classified as '{label}', which contains a forbidden term '{forbidden}'.")
                    return False  # It's definitely not an X-ray

        print("Sanity check PASSED: Image does not appear to be a common non-medical object.")
        return True  # It's plausible enough to proceed

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify every image and aggregate the results.

        Args:
            image_sources: Images given as paths, raw encoded bytes, or arrays.

        Returns:
            ``{"error": ...}`` when no images are supplied or none could be
            processed. Otherwise a dict with:

            * ``final_prediction`` / ``final_confidence``: label and softmax
              probability derived from the mean of the per-image logits.
            * ``individual_results``: one entry per input, in input order;
              failed inputs appear as ``{"prediction": "Error", "confidence": 0}``.
            * ``watermarked_images``: output of ``add_watermark`` for each
              successfully classified image, in input order.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # Each successfully classified image is stored together with its own
        # result so that watermarks can never be paired with another image's
        # prediction when some inputs fail.
        classified_images = []

        for source in image_sources:
            try:
                image = self._load_image(source)

                # --- Perform the relaxed sanity check ---
                if not self.sanity_check(image):
                    raise ValueError("Image appears to be a common object, not a medical scan.")

                inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    logits = self.pneumonia_model(**inputs).logits

                ind_probs = torch.nn.functional.softmax(logits, dim=-1)
                ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
                result = {
                    "prediction": self.id2label[ind_idx.item()],
                    "confidence": ind_conf.item(),
                }
            except Exception as e:
                # Best-effort per image: one unreadable or rejected input must
                # not abort the whole batch.
                print(f"Skipping an invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            # Commit state only once the image has been fully processed, so
            # the three collections always stay consistent with each other.
            all_logits.append(logits)
            individual_results.append(result)
            classified_images.append((np.array(image), result))

        if not all_logits:
            return {
                "error": "Invalid Image",
                "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image.",
            }

        # Aggregate prediction: average the logits, then softmax.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)

        final_prediction = self.id2label[predicted_class_idx.item()]
        final_confidence = confidence_score.item()

        # NOTE: The low-confidence check has been removed as the sanity check is more robust.

        watermarked_images = [
            add_watermark(img_np, res["prediction"], res["confidence"])
            for img_np, res in classified_images
        ]

        return {
            "final_prediction": final_prediction,
            "final_confidence": final_confidence,
            "individual_results": individual_results,
            "watermarked_images": watermarked_images,
        }