Pneumonia-Detection-AI / app /prediction.py
ALYYAN's picture
Prepare application for deployment
48b3884
raw
history blame
3.28 kB
# app/prediction.py
import io
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from .image_utils import add_watermark
# A single image input accepted by ``PredictionPipeline.predict``: a file path,
# raw encoded image bytes, or an already-decoded pixel array.
ImageType = Union[str, Path, bytes, np.ndarray]


class PredictionPipeline:
    """Classify images with a fine-tuned ViT model and aggregate the results.

    The processor and model are loaded once from ``model_path``; the model runs
    on CUDA when available, otherwise on the CPU.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load the image processor and classification model.

        Args:
            model_path: Directory handed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(
            self.device
        )
        # Inference only: switch off training-mode behaviour such as dropout.
        self.model.eval()
        self.id2label = self.model.config.id2label

    @staticmethod
    def _load_image(source: ImageType) -> "Image.Image":
        """Decode one image source into an RGB PIL image."""
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            # Image.open needs a path or file-like object, not raw bytes.
            source = io.BytesIO(source)
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image, so it stays valid after closing.
        with Image.open(source) as opened:
            return opened.convert("RGB")

    def _classify(self, image) -> torch.Tensor:
        """Run the model on one RGB image and return its raw logits."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model(**inputs).logits

    def _summarize(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Map logits for one sample to ``{"prediction": label, "confidence": p}``.

        ``.item()`` requires the max over the last axis to be a single element,
        so ``logits`` must describe exactly one sample.
        """
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence, class_idx = torch.max(probabilities, dim=-1)
        return {
            "prediction": self.id2label[class_idx.item()],
            "confidence": confidence.item(),
        }

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify each image and combine them into one overall prediction.

        Each source is decoded and classified independently. A source that fails
        anywhere in that process is reported as ``"prediction": "Error"`` and is
        left out of the aggregate. The aggregate is the softmax of the mean of
        the per-image logits.

        Args:
            image_sources: File paths, encoded image bytes, or pixel arrays.

        Returns:
            On success, a dict containing:

            - ``final_prediction``: the aggregate label.
            - ``final_confidence``: the probability of that label.
            - ``individual_results``: one ``{"prediction", "confidence"}`` dict
              per input, in input order.
            - ``watermarked_images``: one ``add_watermark`` output per
              successfully classified input, in input order, each drawn with
              that image's own result.

            If no sources are given, or none could be processed, a dict with a
            single ``"error"`` key is returned instead.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # (pixel array, result) for each successful image. Keeping them paired
        # guarantees every watermark uses the matching image's prediction.
        successes = []

        for source in image_sources:
            # Best-effort per image: any failure (unreadable file, bad array,
            # model error) only drops this image, not the whole request.
            try:
                image = self._load_image(source)
                logits = self._classify(image)
                result = self._summarize(logits)
                image_np = np.array(image)
            except Exception as e:
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue
            # Append only after every step succeeded so the lists stay aligned.
            all_logits.append(logits)
            individual_results.append(result)
            successes.append((image_np, result))

        if not all_logits:
            return {"error": "All images were invalid."}

        # Average raw logits across images, then normalize once.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        final = self._summarize(avg_logits)

        watermarked_images = [
            add_watermark(img_np, res["prediction"], res["confidence"])
            for img_np, res in successes
        ]

        return {
            "final_prediction": final["prediction"],
            "final_confidence": final["confidence"],
            "individual_results": individual_results,
            "watermarked_images": watermarked_images,
        }