File size: 2,823 Bytes
4c7c089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""

Model loader utility for FoodViT

Handles loading the trained PyTorch model and feature extractor

"""

import torch
import os
from transformers import ViTForImageClassification, ViTFeatureExtractor
from config import MODEL_CONFIG, CLASS_CONFIG
from huggingface_hub import hf_hub_download

class ModelLoader:
    """Load and hold the FoodViT classification model and its feature extractor.

    Call ``load_model()`` and ``load_feature_extractor()`` to populate the
    instance, then fetch the loaded objects through the ``get_*`` accessors.
    Both loaders report failure by printing the error and returning ``False``
    rather than raising.
    """

    def __init__(self):
        # Populated by load_model() / load_feature_extractor(); None until then.
        self.model = None
        self.feature_extractor = None
        # Device the checkpoint is mapped onto and the model is moved to.
        self.device = MODEL_CONFIG["device"]

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return a parameter state dict from whatever ``torch.load`` produced.

        Three layouts are accepted: an object exposing ``state_dict()`` (e.g. a
        whole pickled module), a dict that wraps the weights under the
        ``"state_dict"`` key, or anything else, which is returned unchanged on
        the assumption that it already is a state dict.
        """
        if hasattr(checkpoint, "state_dict"):
            return checkpoint.state_dict()
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            return checkpoint["state_dict"]
        return checkpoint

    def load_model(self, repo_id="mahmoudalrefaey/FoodViT-weights",
                   filename="bestViT_PT.pth"):
        """Download the trained FoodViT checkpoint and build the model from it.

        A ``ViTForImageClassification`` is created from
        ``MODEL_CONFIG["feature_extractor_name"]`` with
        ``MODEL_CONFIG["num_labels"]`` outputs. The downloaded checkpoint is
        loaded into it, and the model is put in eval mode on ``self.device``.

        Args:
            repo_id: Hugging Face Hub repository that holds the checkpoint.
            filename: Checkpoint file name inside ``repo_id``.

        Returns:
            ``True`` on success. On any error the message is printed and
            ``False`` is returned. ``self.model`` is only assigned after every
            step has succeeded, so a failure never exposes a half-initialized
            model (a previously loaded model, if any, is kept).
        """
        try:
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            # Build into a local: self.model is only replaced once the
            # checkpoint has actually been applied.
            model = ViTForImageClassification.from_pretrained(
                MODEL_CONFIG["feature_extractor_name"],
                num_labels=MODEL_CONFIG["num_labels"],
                # Presumably the base model's head size differs from
                # num_labels; the checkpoint is expected to supply the real
                # head weights — confirm against the training code.
                ignore_mismatched_sizes=True,
            )
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects, i.e. loading can execute code embedded in the file. It
            # is kept because the checkpoint may be a whole pickled module
            # (see _extract_state_dict). Only load from a trusted repo_id.
            checkpoint = torch.load(
                model_path,
                map_location=self.device,
                weights_only=False,
            )
            state_dict = self._extract_state_dict(checkpoint)
            # strict=False tolerates key mismatches; report them instead of
            # silently running with layers the checkpoint did not fill.
            result = model.load_state_dict(state_dict, strict=False)
            if result.missing_keys:
                print(f"Warning: checkpoint is missing "
                      f"{len(result.missing_keys)} key(s): {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Warning: checkpoint has "
                      f"{len(result.unexpected_keys)} unexpected key(s): "
                      f"{result.unexpected_keys}")
            model.eval()
            model.to(self.device)
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

        self.model = model
        print(f"Model loaded successfully on {self.device}")
        return True

    def load_feature_extractor(self):
        """Load the ViT feature extractor named by MODEL_CONFIG.

        Returns:
            ``True`` on success; on any error the message is printed and
            ``False`` is returned.
        """
        # NOTE(review): transformers marks ViTFeatureExtractor as deprecated in
        # favour of ViTImageProcessor; consider migrating once the pinned
        # transformers version is confirmed to provide it.
        try:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(
                MODEL_CONFIG["feature_extractor_name"]
            )
        except Exception as e:
            print(f"Error loading feature extractor: {e}")
            return False
        print("Feature extractor loaded successfully")
        return True

    def get_model(self):
        """Return the loaded model, or None if load_model() has not succeeded."""
        return self.model

    def get_feature_extractor(self):
        """Return the loaded feature extractor, or None if not loaded yet."""
        return self.feature_extractor

    def get_device(self):
        """Return the device read from MODEL_CONFIG["device"] at construction."""
        return self.device

# Global ModelLoader instance, created when this module is imported.
# Construction only reads MODEL_CONFIG["device"]; nothing is downloaded until
# load_model() / load_feature_extractor() are called on it. Presumably other
# modules import this shared instance rather than building their own — confirm.
model_loader = ModelLoader()