Spaces:
Runtime error
Runtime error
| import boto3 | |
| import torch | |
| from PIL import Image | |
| import os | |
| import json | |
| from torchvision import transforms | |
| import timm | |
| from pathlib import Path | |
class ModelLoader:
    """Fetch a trained timm classifier from S3 and expose it for inference.

    On construction the checkpoint stored under ``MODEL_S3_KEY`` in
    ``bucket_name`` is downloaded to ``LOCAL_MODEL_PATH``. It is then loaded
    into a freshly created timm model, moved to ``self.device`` and switched
    to eval mode. The class also provides the label list, one short fact per
    label, and the image preprocessing pipeline used at inference time.

    Attributes:
        bucket_name: S3 bucket that holds the checkpoint.
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: Number of outputs of the classifier head.
        device: ``cuda`` when available, otherwise ``cpu``.
        model: The loaded model, in eval mode, on ``device``.
        labels: Class names returned by ``get_labels`` (presumably in the
            same order as the model's output indices -- confirm against the
            training data).
        facts: Mapping from each label to a short trivia string.
    """

    # S3 object key of the checkpoint to fetch.
    MODEL_S3_KEY = "latest/model.pt"
    # Local file the checkpoint is downloaded to and loaded from.
    LOCAL_MODEL_PATH = "model/latest_model.pt"

    def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
        """Download the latest checkpoint and load it (performs network I/O).

        Args:
            bucket_name: S3 bucket containing ``MODEL_S3_KEY``.
            model_name: timm architecture to instantiate.
            num_classes: Number of output classes of the checkpoint.
        """
        self.bucket_name = bucket_name
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download and load the model.
        self.download_latest_model()
        self.model = self.load_model()

        # Static metadata used to describe predictions.
        self.labels = self.get_labels()
        self.facts = self.get_facts()

    def download_latest_model(self):
        """Download the latest checkpoint from S3 to ``LOCAL_MODEL_PATH``.

        Credentials and region are read from the ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``AWS_DEFAULT_REGION`` environment
        variables.

        Raises:
            RuntimeError: If the S3 client cannot be created or the download
                fails. The underlying error is chained as ``__cause__``.
        """
        # Ensure the destination directory exists before writing into it.
        Path(self.LOCAL_MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)
        try:
            s3 = boto3.client(
                's3',
                aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                region_name=os.getenv('AWS_DEFAULT_REGION'),
            )
            s3.download_file(self.bucket_name, self.MODEL_S3_KEY, self.LOCAL_MODEL_PATH)
        except Exception as e:
            # Chain the cause so the original boto traceback is preserved.
            raise RuntimeError(f"Error downloading model: {str(e)}") from e
        print("Successfully downloaded latest model")

    @staticmethod
    def _clean_state_dict(checkpoint, prefix: str = "model."):
        """Return a plain state dict ready for ``load_state_dict``.

        Accepts either a bare state dict or a checkpoint dict that nests the
        weights under a ``"state_dict"`` entry. When *every* key starts with
        *prefix*, that prefix is removed from the start of each key only.
        Inner occurrences such as ``"a.model.b"`` are left intact.

        Args:
            checkpoint: Object returned by ``torch.load``.
            prefix: Leading key prefix to strip. Presumably the weights were
                saved from a wrapper module that held the network as
                ``self.model``.

        Returns:
            The (possibly re-keyed) state dict.
        """
        state_dict = checkpoint
        # Unwrap full training checkpoints. The isinstance check on the value
        # avoids confusing a tensor entry literally named "state_dict".
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
            state_dict = checkpoint["state_dict"]
        if state_dict and all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        return state_dict

    def load_model(self):
        """Build the timm model and load the downloaded weights into it.

        Returns:
            The model with weights loaded, moved to ``self.device`` and in
            eval mode.
        """
        model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_classes,
        )
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from a trusted bucket (consider weights_only=True on torch >= 1.13).
        checkpoint = torch.load(self.LOCAL_MODEL_PATH, map_location=self.device)
        model.load_state_dict(self._clean_state_dict(checkpoint))
        # Keep the model on the same device as self.device so callers can send
        # their inputs there without a device mismatch.
        model.to(self.device)
        model.eval()
        return model

    def get_labels(self):
        """Return the class labels (alphabetical order, 13 entries)."""
        return [
            "basketball", "boxing", "buildings", "cricket",
            "football", "forest", "formula 1 racing", "glacier",
            "golf", "hockey", "mountain", "sea", "street",
        ]

    def get_facts(self):
        """Return a mapping from each label to a short trivia string."""
        return {
            "basketball": "The NBA's three-point line is 23'9\" from the basket.",
            "boxing": "The first Olympic boxing competition was held in 1904.",
            "buildings": "Modern skyscrapers use advanced materials and engineering to reach incredible heights.",
            "cricket": "Cricket is the second most popular sport in the world.",
            "football": "A soccer ball must be between 27-28 inches in circumference.",
            "forest": "Forests cover about 31% of the world's land surface and are crucial for biodiversity.",
            "formula 1 racing": "The fastest lap in a Formula 1 race is 1:15.328.",
            "glacier": "Glaciers store about 69% of the world's fresh water.",
            "golf": "Golf was first played in Scotland in the 15th century.",
            "hockey": "Hockey is the national sport of Canada.",
            "mountain": "Mount Everest grows about 4mm higher every year.",
            "sea": "The ocean contains 97% of Earth's water and covers 71% of the planet's surface.",
            "street": "The oldest known paved road was built in Egypt around 2600 BC.",
        }

    def get_transforms(self):
        """Return the inference preprocessing pipeline.

        Resizes to 224x224, converts to a tensor and normalizes with the
        ImageNet channel mean/std.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])