import boto3
import torch
from PIL import Image
import os
import json
from torchvision import transforms
import timm
from pathlib import Path


class ModelLoader:
    """Download the latest classifier checkpoint from S3 and prepare it for inference.

    On construction this:
      1. ensures a local ``model/`` directory exists,
      2. downloads ``latest/model.pt`` from ``bucket_name`` to ``model/latest_model.pt``,
      3. builds a ``timm`` model and loads the downloaded weights into it,
      4. populates ``self.labels`` and ``self.facts``.
    """

    def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
        """Download, build and load the model.

        Args:
            bucket_name: S3 bucket holding the checkpoint at ``latest/model.pt``.
            model_name: ``timm`` architecture name passed to ``timm.create_model``.
            num_classes: Size of the classifier head; should match ``len(self.labels)``.

        Raises:
            Exception: If the checkpoint download fails (see ``download_latest_model``).
        """
        self.bucket_name = bucket_name
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Local destination directory for the downloaded checkpoint.
        os.makedirs('model', exist_ok=True)

        # Download and load model
        self.download_latest_model()
        self.model = self.load_model()

        # Load labels and facts
        self.labels = self.get_labels()
        self.facts = self.get_facts()

    def download_latest_model(self):
        """Download ``latest/model.pt`` from the bucket to ``model/latest_model.pt``.

        Credentials and region are read from the ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``AWS_DEFAULT_REGION`` environment variables.

        Raises:
            Exception: Wrapping any error raised while creating the client or
                downloading; the original error is chained as ``__cause__``.
        """
        try:
            s3 = boto3.client(
                's3',
                aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                region_name=os.getenv('AWS_DEFAULT_REGION'),
            )
            s3.download_file(
                self.bucket_name,
                'latest/model.pt',
                'model/latest_model.pt',
            )
            print("Successfully downloaded latest model")
        except Exception as e:
            # Chain the cause so the underlying boto3/botocore traceback is preserved.
            raise Exception(f"Error downloading model: {str(e)}") from e

    def load_model(self):
        """Build the ``timm`` architecture and load ``model/latest_model.pt`` into it.

        If every key in the saved state dict starts with ``'model.'``, that leading
        prefix is removed before loading.

        Returns:
            The model with weights loaded, switched to eval mode.
        """
        model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_classes,
        )

        # NOTE(review): torch.load unpickles the file. It comes from our own bucket,
        # but consider weights_only=True (torch >= 1.13) if the bucket is not fully trusted.
        state_dict = torch.load('model/latest_model.pt', map_location=self.device)

        # Strip the leading prefix only. str.replace would also remove a
        # 'model.' occurring later inside a key and corrupt its name.
        prefix = 'model.'
        if all(k.startswith(prefix) for k in state_dict.keys()):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        # NOTE(review): map_location only decides where the loaded tensors land.
        # load_state_dict copies them into the model's existing parameters, so the
        # model itself is not moved to self.device here. Confirm callers expect
        # that, or add model.to(self.device).
        model.eval()
        return model

    def get_labels(self):
        """Return the class names, in the order of the classifier's output indices."""
        return [
            "basketball", "boxing", "buildings", "cricket", "football",
            "forest", "formula 1 racing", "glacier", "golf", "hockey",
            "mountain", "sea", "street"
        ]

    def get_facts(self):
        """Return a mapping from each class label to a short trivia string."""
        return {
            "basketball": "The NBA's three-point line is 23'9\" from the basket.",
            "boxing": "The first Olympic boxing competition was held in 1904.",
            "buildings": "Modern skyscrapers use advanced materials and engineering to reach incredible heights.",
            "cricket": "Cricket is the second most popular sport in the world.",
            "football": "A soccer ball must be between 27-28 inches in circumference.",
            "forest": "Forests cover about 31% of the world's land surface and are crucial for biodiversity.",
            "formula 1 racing": "The fastest lap in a Formula 1 race is 1:15.328.",
            "glacier": "Glaciers store about 69% of the world's fresh water.",
            "golf": "Golf was first played in Scotland in the 15th century.",
            "hockey": "Hockey is the national sport of Canada.",
            "mountain": "Mount Everest grows about 4mm higher every year.",
            "sea": "The ocean contains 97% of Earth's water and covers 71% of the planet's surface.",
            "street": "The oldest known paved road was built in Egypt around 2600 BC."
        }

    def get_transforms(self):
        """Return the inference preprocessing pipeline.

        The pipeline resizes to 224x224, converts to a tensor and normalizes each
        channel. The mean/std values match the commonly used ImageNet statistics.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])