mahmoudalrefaey commited on
Commit
4c7c089
·
verified ·
1 Parent(s): 5fab1e5

Upload model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +80 -0
utils/model_loader.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loader utility for FoodViT
3
+ Handles loading the trained PyTorch model and feature extractor
4
+ """
5
+
6
+ import torch
7
+ import os
8
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
9
+ from config import MODEL_CONFIG, CLASS_CONFIG
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ class ModelLoader:
13
+ """Class to handle model loading and initialization"""
14
+
15
+ def __init__(self):
16
+ self.model = None
17
+ self.feature_extractor = None
18
+ self.device = MODEL_CONFIG["device"]
19
+
20
+ def load_model(self):
21
+ """Load the trained PyTorch model from Hugging Face Hub"""
22
+ try:
23
+ # Download the model from the Hugging Face Hub
24
+ model_path = hf_hub_download(
25
+ repo_id="mahmoudalrefaey/FoodViT-weights",
26
+ filename="bestViT_PT.pth"
27
+ )
28
+ from transformers import ViTForImageClassification
29
+ self.model = ViTForImageClassification.from_pretrained(
30
+ MODEL_CONFIG["feature_extractor_name"],
31
+ num_labels=MODEL_CONFIG["num_labels"],
32
+ ignore_mismatched_sizes=True
33
+ )
34
+ checkpoint = torch.load(
35
+ model_path,
36
+ map_location=self.device,
37
+ weights_only=False
38
+ )
39
+ if hasattr(checkpoint, 'state_dict'):
40
+ state_dict = checkpoint.state_dict()
41
+ elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
42
+ state_dict = checkpoint['state_dict']
43
+ else:
44
+ state_dict = checkpoint
45
+ self.model.load_state_dict(state_dict, strict=False)
46
+ self.model.eval()
47
+ self.model.to(self.device)
48
+ print(f"Model loaded successfully on {self.device}")
49
+ return True
50
+ except Exception as e:
51
+ print(f"Error loading model: {e}")
52
+ return False
53
+
54
+ def load_feature_extractor(self):
55
+ """Load the ViT feature extractor"""
56
+ try:
57
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(
58
+ MODEL_CONFIG["feature_extractor_name"]
59
+ )
60
+ print("Feature extractor loaded successfully")
61
+ return True
62
+
63
+ except Exception as e:
64
+ print(f"Error loading feature extractor: {e}")
65
+ return False
66
+
67
+ def get_model(self):
68
+ """Get the loaded model"""
69
+ return self.model
70
+
71
+ def get_feature_extractor(self):
72
+ """Get the loaded feature extractor"""
73
+ return self.feature_extractor
74
+
75
+ def get_device(self):
76
+ """Get the current device"""
77
+ return self.device
78
+
79
+ # Global model loader instance
80
+ model_loader = ModelLoader()