Spaces:
Sleeping
Sleeping
File size: 3,559 Bytes
b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 cbf6689 48b3884 b383602 48b3884 cbf6689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# app/prediction.py
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from pathlib import Path
import numpy as np
from typing import List, Dict, Union, Any
from .image_utils import add_watermark
# Accepted image inputs for PredictionPipeline.predict: a filesystem path
# (str or Path), raw bytes, or an already-decoded numpy array.
ImageType = Union[str, Path, bytes, np.ndarray]
class PredictionPipeline:
    """ViT-based image classification pipeline for chest X-rays.

    Loads a fine-tuned ViT model + processor from disk, classifies a batch of
    images by averaging their logits into a single ensemble prediction, and
    watermarks each successfully processed image with its individual result.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load the processor and model from *model_path*, using CUDA if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()
        # Mapping from class index -> human-readable label, taken from the model config.
        self.id2label = self.model.config.id2label

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify a batch of images and return an ensemble prediction.

        Args:
            image_sources: paths, bytes, or numpy arrays to classify.

        Returns:
            On success, a dict with ``final_prediction``, ``final_confidence``,
            ``individual_results`` (one entry per input, ``"Error"`` for inputs
            that could not be processed), and ``watermarked_images`` (one per
            successfully processed input). On failure (no input, all inputs
            invalid, or ensemble confidence below 60%), a dict with an
            ``error`` key instead.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results: List[Dict[str, Any]] = []
        all_logits: List[torch.Tensor] = []
        # (numpy image, result) pairs, appended only after the WHOLE per-image
        # step succeeds. Previously the numpy image was stored before
        # inference while the result (or an "Error" placeholder) was stored
        # after, so any failure desynchronized the two parallel lists and the
        # watermarking zip below could pair an image with the wrong result.
        successes: List[tuple] = []

        for source in image_sources:
            try:
                if isinstance(source, np.ndarray):
                    image = Image.fromarray(source).convert("RGB")
                else:
                    image = Image.open(source).convert("RGB")

                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                logits = outputs.logits

                # Per-image prediction (independent of the ensemble result).
                ind_probs = torch.nn.functional.softmax(logits, dim=-1)
                ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
                result = {
                    "prediction": self.id2label[ind_idx.item()],
                    "confidence": ind_conf.item()
                }

                # Commit all per-image state together so the lists stay aligned.
                all_logits.append(logits)
                individual_results.append(result)
                successes.append((np.array(image), result))
            except Exception as e:
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                # Keep a placeholder so callers can map results back to inputs.
                individual_results.append({"prediction": "Error", "confidence": 0.0})

        if not all_logits:
            return {"error": "All images were invalid."}

        # Ensemble: average the raw logits across images, then softmax once.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
        final_prediction = self.id2label[predicted_class_idx.item()]
        final_confidence = confidence_score.item()

        # Refuse to report a weak ensemble prediction (threshold: 60%).
        if final_confidence < 0.60:
            return {
                "error": "Low Confidence Prediction",
                "details": f"The model's confidence of {final_confidence:.1%} is too low. "
                "Please ensure the uploaded image is a clear, frontal chest X-ray."
            }

        # Watermark each successfully processed image with its own result.
        watermarked_images = [
            add_watermark(img_np, res["prediction"], res["confidence"])
            for img_np, res in successes
        ]

        return {
            "final_prediction": final_prediction,
            "final_confidence": final_confidence,
            "individual_results": individual_results,
            "watermarked_images": watermarked_images
        }
|