File size: 4,696 Bytes
42a7d1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""

Prediction utility for FoodViT

Handles model inference and prediction logic

"""

import torch
import numpy as np
from typing import Dict, Tuple, Optional
from config import CLASS_CONFIG
from utils.model_loader import model_loader
from utils.image_processor import image_processor

class FoodPredictor:
    """Run food-classification inference with the model from ``model_loader``.

    The instance starts without a model; call ``initialize()`` before
    ``predict()``. Every ``predict()`` result carries a ``"success"`` key so
    callers can branch on it without first checking which keys exist.
    """

    def __init__(self) -> None:
        # Populated by initialize(); None means "not loaded yet".
        self.model = None
        self.feature_extractor = None
        self.device = None
        # Label metadata comes straight from the project configuration.
        self.class_names = CLASS_CONFIG["class_names"]
        self.id2label = CLASS_CONFIG["id2label"]

    def initialize(self) -> bool:
        """Load the model and feature extractor through ``model_loader``.

        Returns:
            bool: True when both loads succeed and the handles are stored on
            this instance; False when either loader reports failure or any
            exception is raised (the exception is printed, not propagated).
        """
        try:
            # Both loaders signal failure with a falsy return value.
            if not model_loader.load_model():
                return False
            if not model_loader.load_feature_extractor():
                return False

            self.model = model_loader.get_model()
            self.feature_extractor = model_loader.get_feature_extractor()
            self.device = model_loader.get_device()

            print("Predictor initialized successfully")
            return True

        except Exception as e:
            print(f"Error initializing predictor: {e}")
            return False

    def predict(self, image_input) -> Dict:
        """Predict the food class for a single image.

        Args:
            image_input: Image path, PIL Image, or numpy array; it is passed
                unchanged to ``image_processor.preprocess_image``.

        Returns:
            dict: On success, the keys ``"class"``, ``"class_id"``,
            ``"confidence"``, ``"probabilities"`` (label -> probability) and
            ``"success": True``. On any failure, ``"error"`` (message) and
            ``"success": False``. Exceptions are caught and reported through
            the error dictionary rather than raised.
        """
        try:
            if self.model is None:
                # Include "success" so this failure has the same shape as the
                # exception path below.
                return {"error": "Model not initialized", "success": False}

            processed_image = image_processor.preprocess_image(image_input)
            if processed_image is None:
                return {"error": "Failed to preprocess image", "success": False}

            # The input must live on the same device as the model.
            processed_image = processed_image.to(self.device)

            # Gradients are not needed for inference.
            with torch.no_grad():
                outputs = self.model(processed_image)
                probabilities = torch.softmax(outputs.logits, dim=1)

            # .item() requires a single-element tensor, i.e. one image per call.
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class].item()

            # Copy to host memory once for the per-class breakdown.
            all_probabilities = probabilities[0].cpu().numpy()

            return {
                "class": self.id2label[predicted_class],
                "class_id": predicted_class,
                "confidence": confidence,
                "probabilities": {
                    self.id2label[i]: float(all_probabilities[i])
                    for i in range(len(self.class_names))
                },
                "success": True,
            }

        except Exception as e:
            print(f"Error during prediction: {e}")
            return {"error": str(e), "success": False}

    def predict_batch(self, image_inputs) -> list:
        """Predict food classes for multiple images, one ``predict()`` call each.

        Args:
            image_inputs: Iterable of image inputs accepted by ``predict()``.

        Returns:
            list: One result dictionary per input, in input order. Failures
            appear as error dictionaries; they do not abort the batch.
        """
        return [self.predict(image_input) for image_input in image_inputs]

    def get_model_info(self) -> Dict:
        """Summarize the loaded model.

        Returns:
            dict: ``"device"``, ``"total_parameters"``,
            ``"trainable_parameters"``, ``"num_classes"`` and
            ``"class_names"``; or a dictionary with only ``"error"`` when no
            model is loaded or the parameters cannot be inspected.
        """
        if self.model is None:
            return {"error": "Model not loaded"}

        try:
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(
                p.numel() for p in self.model.parameters() if p.requires_grad
            )

            return {
                "device": str(self.device),
                "total_parameters": total_params,
                "trainable_parameters": trainable_params,
                "num_classes": len(self.class_names),
                "class_names": self.class_names,
            }
        except Exception as e:
            return {"error": str(e)}

# Global predictor instance, created at import time.
# Construction only reads CLASS_CONFIG; no model is loaded until
# predictor.initialize() is called.
predictor = FoodPredictor()