# Hugging Face Spaces file view -- Space status: Sleeping
# File size: 2,823 Bytes
# Revision: 4c7c089
"""
Model loader utility for FoodViT
Handles loading the trained PyTorch model and feature extractor
"""
import torch
import os
from transformers import ViTForImageClassification, ViTFeatureExtractor
from config import MODEL_CONFIG, CLASS_CONFIG
from huggingface_hub import hf_hub_download
class ModelLoader:
    """Load and hold the FoodViT classifier and its ViT feature extractor.

    The model architecture is instantiated from the checkpoint named by
    ``MODEL_CONFIG["feature_extractor_name"]`` with
    ``MODEL_CONFIG["num_labels"]`` output labels, and the trained weights
    are then loaded from a ``.pth`` file downloaded from the Hugging Face Hub.
    """

    # Default Hub location of the trained weights file.
    DEFAULT_REPO_ID = "mahmoudalrefaey/FoodViT-weights"
    DEFAULT_FILENAME = "bestViT_PT.pth"

    def __init__(self):
        """Create an empty loader; nothing is downloaded or loaded yet."""
        self.model = None
        self.feature_extractor = None
        self.device = MODEL_CONFIG["device"]

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the parameter mapping contained in a loaded checkpoint.

        Accepted checkpoint shapes:
          * an object exposing a callable ``state_dict()`` (e.g. a whole
            pickled module),
          * a dict wrapping the weights under ``"state_dict"`` or
            ``"model_state_dict"``,
          * anything else, which is assumed to already be a state dict and
            is returned unchanged.
        """
        state_dict_fn = getattr(checkpoint, "state_dict", None)
        if callable(state_dict_fn):
            return state_dict_fn()
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model_state_dict"):
                if wrapper_key in checkpoint:
                    return checkpoint[wrapper_key]
        return checkpoint

    def load_model(self, repo_id=DEFAULT_REPO_ID, filename=DEFAULT_FILENAME):
        """Load the trained PyTorch model from the Hugging Face Hub.

        Args:
            repo_id: Hub repository that hosts the weights file.
            filename: Name of the weights file inside ``repo_id``.

        Returns:
            True when the model was built, populated with the checkpoint
            weights, switched to eval mode and moved to ``self.device``;
            False on any error (the error is printed). On failure
            ``self.model`` is left exactly as it was before the call.
        """
        try:
            # Download the model from the Hugging Face Hub
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)

            # Build into a local first so a failure below never leaves a
            # half-initialised model (base weights only) visible to callers.
            model = ViTForImageClassification.from_pretrained(
                MODEL_CONFIG["feature_extractor_name"],
                num_labels=MODEL_CONFIG["num_labels"],
                ignore_mismatched_sizes=True,
            )

            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # objects, which can execute code embedded in the file. It is
            # kept because the checkpoint may be a fully pickled module (see
            # _extract_state_dict); only point repo_id at trusted sources.
            checkpoint = torch.load(
                model_path,
                map_location=self.device,
                weights_only=False,
            )
            state_dict = self._extract_state_dict(checkpoint)

            # strict=False tolerates key differences, so report them instead
            # of silently announcing success with partially loaded weights.
            load_result = model.load_state_dict(state_dict, strict=False)
            missing_keys = list(getattr(load_result, "missing_keys", ()))
            unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
            if missing_keys:
                print(
                    f"Warning: {len(missing_keys)} model parameter(s) "
                    f"not found in checkpoint"
                )
            if unexpected_keys:
                print(
                    f"Warning: {len(unexpected_keys)} checkpoint key(s) "
                    f"not used by the model"
                )

            model.eval()
            model.to(self.device)
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

        self.model = model
        print(f"Model loaded successfully on {self.device}")
        return True

    def load_feature_extractor(self):
        """Load the ViT feature extractor.

        Returns:
            True when the extractor named by
            ``MODEL_CONFIG["feature_extractor_name"]`` was loaded into
            ``self.feature_extractor``; False on any error (the error is
            printed and the attribute is left unchanged).
        """
        # NOTE(review): newer transformers releases deprecate
        # ViTFeatureExtractor in favour of ViTImageProcessor -- confirm the
        # pinned transformers version before upgrading.
        try:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(
                MODEL_CONFIG["feature_extractor_name"]
            )
        except Exception as e:
            print(f"Error loading feature extractor: {e}")
            return False

        print("Feature extractor loaded successfully")
        return True

    def get_model(self):
        """Get the loaded model (None until load_model() succeeds)."""
        return self.model

    def get_feature_extractor(self):
        """Get the loaded feature extractor (None until it is loaded)."""
        return self.feature_extractor

    def get_device(self):
        """Get the device taken from MODEL_CONFIG["device"] at construction."""
        return self.device
# Global model loader instance.
# Constructing it only reads MODEL_CONFIG["device"]; no model or feature
# extractor is loaded until the corresponding load_* method is called.
model_loader = ModelLoader()