# Source: FoodClassifier-ViT / utils/predictor.py
# Uploaded by mahmoudalrefaey ("Upload 4 files", commit 42a7d1b, verified)
"""
Prediction utility for FoodViT
Handles model inference and prediction logic
"""
import torch
import numpy as np
from typing import Dict, Tuple, Optional
from config import CLASS_CONFIG
from utils.model_loader import model_loader
from utils.image_processor import image_processor
class FoodPredictor:
    """Run food-classification inference using the shared model loader.

    The instance starts without a model; call :meth:`initialize` once before
    :meth:`predict`. Every result dictionary produced by :meth:`predict`
    carries a boolean ``"success"`` key so callers can branch on it without
    first checking which keys are present.
    """

    def __init__(self) -> None:
        # Model, feature extractor and device stay unset until initialize()
        # succeeds; only the label configuration is read at construction time.
        self.model = None
        self.feature_extractor = None
        self.device = None
        self.class_names = CLASS_CONFIG["class_names"]
        self.id2label = CLASS_CONFIG["id2label"]

    def initialize(self) -> bool:
        """Load the model and feature extractor through ``model_loader``.

        Returns:
            bool: True when both loads report success and the handles were
            stored on this instance; False when either load reports failure
            or any exception is raised (the exception is printed, not
            propagated).
        """
        try:
            # Short-circuits exactly like two sequential checks: the feature
            # extractor is only loaded when the model load succeeded.
            if not (model_loader.load_model()
                    and model_loader.load_feature_extractor()):
                return False

            self.model = model_loader.get_model()
            self.feature_extractor = model_loader.get_feature_extractor()
            self.device = model_loader.get_device()
            print("Predictor initialized successfully")
            return True
        except Exception as e:
            print(f"Error initializing predictor: {e}")
            return False

    def predict(self, image_input) -> Dict:
        """Predict the food class for a single image.

        Args:
            image_input: Image path, PIL Image, or numpy array; it is handed
                unchanged to ``image_processor.preprocess_image``.

        Returns:
            dict: On success, ``"class"``, ``"class_id"``, ``"confidence"``,
            ``"probabilities"`` (label -> probability) and ``"success": True``.
            On any failure, ``"error"`` (message) and ``"success": False``.
            Exceptions are caught, printed, and reported in the result rather
            than propagated.
        """
        try:
            if self.model is None:
                return {"error": "Model not initialized", "success": False}

            processed_image = image_processor.preprocess_image(image_input)
            if processed_image is None:
                return {"error": "Failed to preprocess image", "success": False}

            processed_image = processed_image.to(self.device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                logits = self.model(processed_image).logits
                probabilities = torch.softmax(logits, dim=1)

            # .item() requires exactly one element, so this path expects a
            # batch containing a single image.
            predicted_class = torch.argmax(probabilities, dim=1).item()

            # Plain Python floats for the first (only) row of the batch.
            all_probabilities = probabilities[0].cpu().tolist()
            confidence = all_probabilities[predicted_class]

            return {
                "class": self.id2label[predicted_class],
                "class_id": predicted_class,
                "confidence": confidence,
                "probabilities": {
                    # Indexed by the configured class count (not the logit
                    # count) so a config/model mismatch surfaces as an error.
                    self.id2label[i]: all_probabilities[i]
                    for i in range(len(self.class_names))
                },
                "success": True,
            }
        except Exception as e:
            print(f"Error during prediction: {e}")
            return {"error": str(e), "success": False}

    def predict_batch(self, image_inputs) -> list:
        """Predict food classes for multiple images, one at a time.

        Args:
            image_inputs: Iterable of inputs accepted by :meth:`predict`.

        Returns:
            list: One :meth:`predict` result per input, in input order. A
            failing input yields an error dictionary and does not stop the
            remaining predictions.
        """
        return [self.predict(image_input) for image_input in image_inputs]

    def get_model_info(self) -> Dict:
        """Summarize the loaded model.

        Returns:
            dict: ``"device"``, ``"total_parameters"``,
            ``"trainable_parameters"``, ``"num_classes"`` and
            ``"class_names"``; or a dictionary with only ``"error"`` when no
            model is loaded or the parameters cannot be inspected.
        """
        if self.model is None:
            return {"error": "Model not loaded"}

        try:
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(
                p.numel() for p in self.model.parameters() if p.requires_grad
            )
            return {
                "device": str(self.device),
                "total_parameters": total_params,
                "trainable_parameters": trainable_params,
                "num_classes": len(self.class_names),
                "class_names": self.class_names,
            }
        except Exception as e:
            return {"error": str(e)}
# Global predictor instance shared by importers of this module.
# Constructing it only reads CLASS_CONFIG; no model is loaded here, so
# predictor.initialize() must be called before predictor.predict() can
# return anything other than a "Model not initialized" error.
predictor = FoodPredictor()