Spaces:
Running
Running
File size: 5,044 Bytes
b383602 3726151 b383602 48b3884 b383602 48b3884 b383602 3726151 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 3726151 48b3884 b383602 3726151 b383602 3726151 48b3884 b383602 3726151 48b3884 b383602 3726151 b383602 3726151 b383602 48b3884 b383602 48b3884 cbf6689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# app/prediction.py
import torch
from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, ResNetForImageClassification
from PIL import Image
from pathlib import Path
import numpy as np
from typing import List, Dict, Union, Any
from .image_utils import add_watermark
# Input types accepted by PredictionPipeline.predict: NumPy arrays are converted
# with Image.fromarray; every other value is handed to Image.open.
ImageType = Union[str, Path, bytes, np.ndarray]
class PredictionPipeline:
    """Two-stage image classification pipeline.

    Every input image is first screened by a general-purpose ResNet-50
    classifier (``is_likely_xray``). Images that pass are classified by the
    fine-tuned ViT model loaded from ``model_path``; the per-image logits are
    averaged to produce one aggregate prediction.
    """

    # Number of top-scoring sanity-model classes inspected per image.
    _SANITY_TOP_K = 5
    # A sanity-model label containing any of these substrings counts as a pass.
    _XRAY_KEYWORDS = ("x-ray", "chest", "radiograph")

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load both models and move them to the GPU when one is available.

        Args:
            model_path: Directory holding the fine-tuned ViT processor and
                model, as accepted by ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # --- Pneumonia model (the fine-tuned model) ---
        self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
        self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.pneumonia_model.eval()
        self.id2label = self.pneumonia_model.config.id2label

        # --- Sanity-check model (general purpose) ---
        # NOTE(review): microsoft/resnet-50 is presumably trained on
        # ImageNet-1k; confirm that its label set really contains
        # X-ray-like classes, otherwise the check cannot pass for the
        # intended inputs.
        self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
        self.sanity_model.eval()

    def is_likely_xray(self, image: Image.Image) -> bool:
        """Return True when the sanity model's top classes look X-ray-like.

        The image is classified by the ResNet-50 sanity model and the labels
        of its highest-probability classes (up to five) are searched for the
        substrings "x-ray", "chest" or "radiograph".

        Args:
            image: The image to screen.

        Returns:
            True if any inspected label contains one of the keywords,
            False otherwise. The outcome is also printed.
        """
        with torch.no_grad():
            inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
            logits = self.sanity_model(**inputs).logits

        probabilities = logits.softmax(-1)
        # Clamp k so topk cannot fail on a model with fewer than five classes.
        k = min(self._SANITY_TOP_K, probabilities.shape[-1])
        _, top_indices = torch.topk(probabilities, k)

        id2label = self.sanity_model.config.id2label
        for idx in top_indices[0]:
            label = id2label[idx.item()].lower()
            if any(keyword in label for keyword in self._XRAY_KEYWORDS):
                print(f"Sanity check passed: Image classified as '{label}'")
                return True

        print("Sanity check failed: Image is not classified as an X-ray.")
        return False

    @staticmethod
    def _load_rgb(source: ImageType) -> Image.Image:
        """Load ``source`` as an RGB PIL image without leaking a file handle."""
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        # NOTE(review): Image.open appears to treat a ``bytes`` value as a
        # filename rather than as encoded image data; callers holding raw
        # image bytes probably need to wrap them in io.BytesIO - confirm.
        with Image.open(source) as opened:
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed as soon as the block exits.
            return opened.convert("RGB")

    def _classify(self, image: Image.Image):
        """Run the pneumonia model on one image; return (logits, result dict)."""
        inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.pneumonia_model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence, class_idx = torch.max(probabilities, dim=-1)
        result = {
            "prediction": self.id2label[class_idx.item()],
            "confidence": confidence.item(),
        }
        return logits, result

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify a batch of images and aggregate the results.

        Each source is loaded, screened with ``is_likely_xray`` and, if it
        passes, classified by the pneumonia model. A source that fails at any
        step is reported as ``{"prediction": "Error", "confidence": 0}`` and
        excluded from the aggregate and from the watermarked output.

        Args:
            image_sources: Images to classify.

        Returns:
            ``{"error": ...}`` when no sources were given or none were usable
            (the latter also carries ``"details"``). Otherwise a dict with:

            * ``final_prediction`` / ``final_confidence``: label and softmax
              probability computed from the mean of the per-image logits.
            * ``individual_results``: one entry per input, in input order.
            * ``watermarked_images``: one ``add_watermark`` result per
              successfully classified image, in input order.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # (RGB array, result) pairs for the images that were classified. They
        # are kept together so a skipped input can never shift an image onto
        # another image's prediction.
        classified = []

        for source in image_sources:
            try:
                image = self._load_rgb(source)
                if not self.is_likely_xray(image):
                    raise ValueError("Image does not appear to be a chest X-ray.")
                logits, result = self._classify(image)
                image_np = np.array(image)
            except Exception as e:
                # Best effort: one bad input must not abort the whole batch.
                print(f"Skipping an invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            # Record only after every step succeeded, so the collections
            # below stay consistent with each other.
            all_logits.append(logits)
            classified.append((image_np, result))
            individual_results.append(result)

        if not all_logits:
            return {"error": "Invalid Image", "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image."}

        # Aggregate by averaging raw logits across images, then softmax.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)

        watermarked_images = [
            add_watermark(image_np, result["prediction"], result["confidence"])
            for image_np, result in classified
        ]

        return {
            "final_prediction": self.id2label[predicted_class_idx.item()],
            "final_confidence": confidence_score.item(),
            "individual_results": individual_results,
            "watermarked_images": watermarked_images,
        }
|