# Scenes-Sports-Classifier / utils/model_loader.py
# Author: HimankJ
# Commit b5b42e9 (verified): "Upload folder using huggingface_hub"
import boto3
import torch
from PIL import Image
import os
import json
from torchvision import transforms
import timm
from pathlib import Path
class ModelLoader:
    """Download a trained timm classifier from S3 and expose it for inference.

    On construction this creates the local ``model`` directory, downloads
    ``latest/model.pt`` from ``bucket_name`` and loads it into a fresh
    ``timm`` model on ``self.device``. It also populates the class label list
    and a per-class fun-fact dictionary.

    Attributes:
        bucket_name: S3 bucket holding the checkpoint.
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: Size of the classifier head.
        device: ``cuda`` when available, otherwise ``cpu``; the model is
            placed on this device.
        model: The loaded model, in eval mode.
        labels: Class names, as returned by :meth:`get_labels`.
        facts: Mapping label -> fact string, as returned by :meth:`get_facts`.
    """

    # Local directory for downloaded checkpoints (relative to the CWD).
    MODEL_DIR = "model"
    # Object key of the checkpoint inside the bucket.
    S3_MODEL_KEY = "latest/model.pt"
    # Where the downloaded checkpoint is written and later read from.
    MODEL_PATH = "model/latest_model.pt"
    # Prefix stripped from state-dict keys when every key carries it.
    STATE_DICT_PREFIX = "model."

    def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
        self.bucket_name = bucket_name
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Ensure the download target directory exists.
        os.makedirs(self.MODEL_DIR, exist_ok=True)

        # Fetch the checkpoint, then build the model from it.
        self.download_latest_model()
        self.model = self.load_model()

        # Static metadata used to present predictions.
        self.labels = self.get_labels()
        self.facts = self.get_facts()

    def download_latest_model(self):
        """Download the latest checkpoint from S3 to ``MODEL_PATH``.

        Credentials and region are read from the ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``AWS_DEFAULT_REGION`` environment
        variables. When one is unset, ``None`` is passed to boto3.

        Raises:
            Exception: Wraps any failure from client creation or download;
                the original error is chained as ``__cause__``.
        """
        try:
            s3 = boto3.client(
                's3',
                aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                region_name=os.getenv('AWS_DEFAULT_REGION')
            )
            s3.download_file(self.bucket_name, self.S3_MODEL_KEY, self.MODEL_PATH)
            print("Successfully downloaded latest model")
        except Exception as e:
            # Chain the cause so the real boto3/botocore traceback is kept.
            raise Exception(f"Error downloading model: {e}") from e

    @staticmethod
    def _strip_model_prefix(state_dict, prefix="model."):
        """Return ``state_dict`` with a leading ``prefix`` removed from every key.

        The prefix is stripped only when *all* keys start with it. Only the
        leading occurrence is removed, so a key such as ``model.x.model.y``
        becomes ``x.model.y``. If any key lacks the prefix, the input mapping
        is returned unchanged.
        """
        if all(k.startswith(prefix) for k in state_dict.keys()):
            return {k[len(prefix):]: v for k, v in state_dict.items()}
        return state_dict

    def load_model(self):
        """Build the timm model and load the downloaded weights into it.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.
        """
        model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_classes
        )

        # NOTE(review): torch.load unpickles the file. The checkpoint comes
        # from S3, so this trusts whoever can write to the bucket. Consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(self.MODEL_PATH, map_location=self.device)

        # The checkpoint's keys may be namespaced under "model." (presumably
        # saved from a wrapper module; verify against the training code).
        state_dict = self._strip_model_prefix(state_dict, self.STATE_DICT_PREFIX)

        model.load_state_dict(state_dict)
        # map_location only places the loaded tensors. load_state_dict copies
        # them into the model's own (CPU) parameters, so move the model too.
        model.to(self.device)
        model.eval()
        return model

    def get_labels(self):
        """Return the class names, in alphabetical order.

        NOTE(review): index ``i`` presumably corresponds to model output ``i``
        (as with torchvision ``ImageFolder`` sorted class folders). Confirm
        against the training pipeline.
        """
        return [
            "basketball", "boxing", "buildings", "cricket",
            "football", "forest", "formula 1 racing", "glacier",
            "golf", "hockey", "mountain", "sea", "street"
        ]

    def get_facts(self):
        """Return a mapping from each class label to a short trivia string."""
        return {
            "basketball": "The NBA's three-point line is 23'9\" from the basket.",
            "boxing": "The first Olympic boxing competition was held in 1904.",
            "buildings": "Modern skyscrapers use advanced materials and engineering to reach incredible heights.",
            "cricket": "Cricket is the second most popular sport in the world.",
            "football": "A soccer ball must be between 27-28 inches in circumference.",
            "forest": "Forests cover about 31% of the world's land surface and are crucial for biodiversity.",
            "formula 1 racing": "The fastest lap in a Formula 1 race is 1:15.328.",
            "glacier": "Glaciers store about 69% of the world's fresh water.",
            "golf": "Golf was first played in Scotland in the 15th century.",
            "hockey": "Hockey is the national sport of Canada.",
            "mountain": "Mount Everest grows about 4mm higher every year.",
            "sea": "The ocean contains 97% of Earth's water and covers 71% of the planet's surface.",
            "street": "The oldest known paved road was built in Egypt around 2600 BC."
        }

    def get_transforms(self):
        """Return the inference preprocessing pipeline.

        The pipeline resizes to 224x224, converts to a tensor and normalizes
        with the standard ImageNet channel mean/std.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])