Spaces:
Sleeping
Sleeping
File size: 3,278 Bytes
b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 48b3884 b383602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# app/prediction.py
import io
from pathlib import Path
from typing import List, Dict, Union, Any

import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

from .image_utils import add_watermark
# Kinds of per-image input accepted by PredictionPipeline.predict:
# a path (str or pathlib.Path), a bytes object, or a NumPy pixel array.
ImageType = Union[
    str,
    Path,
    bytes,
    np.ndarray,
]
class PredictionPipeline:
    """Classify one or more images with a ViT model and aggregate the results.

    The image processor and classification model are loaded once, at
    construction time, from ``model_path`` and placed on CUDA when it is
    available, otherwise on CPU.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load the ViT image processor and classification model.

        Args:
            model_path: Location passed to ``from_pretrained`` for both the
                image processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        # Inference only: put the module in evaluation mode.
        self.model.eval()
        # Maps class index -> human-readable label, as stored in the model config.
        self.id2label = self.model.config.id2label

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify each image individually and combine them into one prediction.

        Each source is decoded and classified on its own. A source that cannot
        be decoded or classified is reported as
        ``{"prediction": "Error", "confidence": 0}`` and excluded from the
        aggregate. The aggregate prediction is the argmax of the softmax over
        the mean of the valid images' logits.

        Args:
            image_sources: One entry per image: a file path (``str``/``Path``),
                encoded image bytes, or a pixel array accepted by
                ``PIL.Image.fromarray``.

        Returns:
            On success, a dict with:
              - ``final_prediction``: aggregate label.
              - ``final_confidence``: aggregate probability of that label.
              - ``individual_results``: one ``{"prediction", "confidence"}``
                entry per input source, in input order.
              - ``watermarked_images``: one ``add_watermark`` output per
                *valid* image, in input order, each annotated with that
                image's own prediction.
            If ``image_sources`` is empty, or no image could be processed,
            a dict with a single ``"error"`` key.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # (pixels, result) for every successfully classified image. Keeping
        # them paired guarantees each image is watermarked with its own
        # result, even when earlier sources failed.
        valid_images = []

        for source in image_sources:
            try:
                image = self._load_image(source)
                logits = self._classify(image)
                probs = torch.nn.functional.softmax(logits, dim=-1)
                conf, idx = torch.max(probs, dim=-1)
                result = {
                    "prediction": self.id2label[idx.item()],
                    "confidence": conf.item(),
                }
                pixels = np.array(image)
            except Exception as e:
                # Best-effort: one bad input must not abort the whole batch.
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            # Commit only after every step succeeded so the collections stay in step.
            individual_results.append(result)
            all_logits.append(logits)
            valid_images.append((pixels, result))

        if not all_logits:
            return {"error": "All images were invalid."}

        # Aggregate: average the raw logits of all valid images, then softmax once.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)

        # Watermark each valid image with its INDIVIDUAL (not aggregate) result.
        watermarked_images = [
            add_watermark(pixels, res["prediction"], res["confidence"])
            for pixels, res in valid_images
        ]

        return {
            "final_prediction": self.id2label[predicted_class_idx.item()],
            "final_confidence": confidence_score.item(),
            "individual_results": individual_results,
            "watermarked_images": watermarked_images,
        }

    @staticmethod
    def _load_image(source: ImageType) -> "Image.Image":
        """Decode a single image source into an RGB PIL image."""
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            # PIL.Image.open treats a bytes argument as a *filename*; wrap raw
            # encoded image data in a file-like object instead.
            source = io.BytesIO(source)
        # The context manager releases the underlying file; convert() returns
        # a new, fully loaded image that does not depend on it.
        with Image.open(source) as img:
            return img.convert("RGB")

    def _classify(self, image: "Image.Image") -> torch.Tensor:
        """Run the model on one image and return its raw (pre-softmax) logits."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model(**inputs).logits