File size: 4,131 Bytes
2a7e139
 
 
 
 
 
 
8db8e2d
2a7e139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d997b45
 
 
 
 
 
8db8e2d
 
 
2a7e139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5b42e9
 
 
 
 
 
 
 
 
 
2a7e139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import boto3
import torch
from PIL import Image
import os
import json
from torchvision import transforms
import timm
from pathlib import Path

class ModelLoader:
    """Fetch the latest classifier checkpoint from S3 and prepare it for inference.

    Construction performs network and disk I/O: it downloads
    ``latest/model.pt`` from ``bucket_name`` to ``model/latest_model.pt``,
    rebuilds the ``timm`` architecture ``model_name`` with a
    ``num_classes``-way head, loads the downloaded weights, moves the model
    to ``device`` and switches it to eval mode.

    Attributes:
        bucket_name: S3 bucket holding the checkpoint.
        model_name: ``timm`` architecture identifier passed to ``create_model``.
        num_classes: Number of outputs of the classifier head.
        device: ``cuda`` when available, otherwise ``cpu``.
        model: The loaded model, on ``device`` and in eval mode.
        labels: Class names returned by ``get_labels``.
        facts: Label -> trivia-string mapping returned by ``get_facts``.
    """

    # S3 object key of the checkpoint that gets downloaded.
    S3_MODEL_KEY = 'latest/model.pt'
    # Local directory the checkpoint is stored in (created on construction).
    MODEL_DIR = 'model'
    # Local path the checkpoint is downloaded to and loaded from.
    LOCAL_MODEL_PATH = 'model/latest_model.pt'
    # Key prefix stripped from the checkpoint when *every* key carries it.
    # Presumably added because the network was saved from a wrapper that held
    # it as ``self.model`` — confirm against the training code.
    _STATE_DICT_PREFIX = 'model.'

    def __init__(self, bucket_name: str, model_name: str = "resnet18", num_classes: int = 13):
        self.bucket_name = bucket_name
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The checkpoint is downloaded into this directory, so it must exist first.
        os.makedirs(self.MODEL_DIR, exist_ok=True)

        self.download_latest_model()
        self.model = self.load_model()

        # NOTE(review): get_labels() returns a fixed list of 13 names; nothing
        # checks that it matches num_classes — confirm callers never change
        # num_classes without also changing the label list.
        self.labels = self.get_labels()
        self.facts = self.get_facts()

    def download_latest_model(self):
        """Download the latest checkpoint from S3 to ``LOCAL_MODEL_PATH``.

        Credentials and region come from the ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``AWS_DEFAULT_REGION`` environment
        variables; unset variables are passed to boto3 as ``None``.

        Raises:
            RuntimeError: If creating the client or downloading fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            s3 = boto3.client(
                's3',
                aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
                region_name=os.getenv('AWS_DEFAULT_REGION')
            )
            s3.download_file(self.bucket_name, self.S3_MODEL_KEY, self.LOCAL_MODEL_PATH)
        except Exception as e:
            # Chain the cause so the boto/botocore traceback is not lost.
            raise RuntimeError(f"Error downloading model: {str(e)}") from e
        print("Successfully downloaded latest model")

    def load_model(self):
        """Build the ``timm`` model, load the downloaded weights and prepare it for inference.

        Returns:
            The model, moved to ``self.device`` and set to eval mode.
        """
        model = timm.create_model(
            self.model_name,
            pretrained=False,
            num_classes=self.num_classes
        )

        # NOTE(review): torch.load unpickles the file. It comes from our own
        # bucket, but if that bucket can be written by anyone untrusted,
        # consider weights_only=True (supported from torch 1.13).
        # Load onto the CPU: load_state_dict copies the values into the
        # model's own (CPU) parameters anyway, and the model is moved below.
        state_dict = torch.load(self.LOCAL_MODEL_PATH, map_location='cpu')
        state_dict = self._strip_prefix(state_dict, self._STATE_DICT_PREFIX)

        model.load_state_dict(state_dict)
        # create_model builds on the CPU and map_location does not move the
        # module, so place it on the target device explicitly.
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return ``state_dict`` with a leading ``prefix`` removed from every key.

        Keys are only rewritten when *all* of them start with ``prefix``;
        otherwise ``state_dict`` is returned unchanged. Only the leading
        occurrence is removed, so ``'model.a.model.b'`` becomes ``'a.model.b'``.
        """
        if not all(key.startswith(prefix) for key in state_dict):
            return state_dict
        cut = len(prefix)
        return {key[cut:]: value for key, value in state_dict.items()}

    def get_labels(self):
        """Return the class names as a new list (13 entries)."""
        return [
            "basketball", "boxing", "buildings", "cricket",
            "football", "forest", "formula 1 racing", "glacier",
            "golf", "hockey", "mountain", "sea", "street"
        ]

    def get_facts(self):
        """Return a mapping from each label in ``get_labels`` to a short trivia string."""
        return {
            "basketball": "The NBA's three-point line is 23'9\" from the basket.",
            "boxing": "The first Olympic boxing competition was held in 1904.",
            "buildings": "Modern skyscrapers use advanced materials and engineering to reach incredible heights.",
            "cricket": "Cricket is the second most popular sport in the world.",
            "football": "A soccer ball must be between 27-28 inches in circumference.",
            "forest": "Forests cover about 31% of the world's land surface and are crucial for biodiversity.",
            "formula 1 racing": "The fastest lap in a Formula 1 race is 1:15.328.",
            "glacier": "Glaciers store about 69% of the world's fresh water.",
            "golf": "Golf was first played in Scotland in the 15th century.",
            "hockey": "Hockey is the national sport of Canada.",
            "mountain": "Mount Everest grows about 4mm higher every year.",
            "sea": "The ocean contains 97% of Earth's water and covers 71% of the planet's surface.",
            "street": "The oldest known paved road was built in Egypt around 2600 BC."
        }

    def get_transforms(self):
        """Return the preprocessing pipeline applied to a PIL image before inference.

        The image is resized to exactly 224x224 (aspect ratio is not kept),
        converted to a tensor, and normalized with the commonly used ImageNet
        channel mean/std — presumably matching the training pipeline; confirm.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])