File size: 6,572 Bytes
505fc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
Model loading utilities
Handles loading models from different sources: local files, HuggingFace, ClearML
"""

import torch
import sys
from pathlib import Path

# Make the project root importable so `models` and `config` resolve even when
# this module is run directly from its own subdirectory.
sys.path.append(str(Path(__file__).parent.parent))

from models.mock_model import MockPlantDiseaseModel, create_mock_predictions
import config


class ModelLoader:
    """
    Handles loading and managing plant disease models.

    Supports a mock model for development, plus real models loaded from
    local weight files, ClearML tasks, or the HuggingFace Hub. Every model
    returned by the public load_* methods is moved to ``self.device`` and
    switched to eval mode.
    """

    def __init__(self, use_mock=True):
        """
        Initialize model loader.

        Args:
            use_mock: If True, use mock model for development.
        """
        self.use_mock = use_mock
        self.model = None  # set by the load_* methods
        # Prefer GPU when available; all loaded models are moved here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name="CNN from Scratch", model_path=None):
        """
        Load a model based on configuration.

        Args:
            model_name: Name of the model configuration (key in config.MODEL_CONFIGS).
            model_path: Optional path to model weights.

        Returns:
            Loaded model, moved to ``self.device`` and in eval mode.
        """
        if self.use_mock:
            print("Loading mock model for development...")
            self.model = self._load_mock_model()
        else:
            print(f"Loading real model: {model_name}")
            self.model = self._load_real_model(model_name, model_path)

        return self._finalize(self.model)

    def _finalize(self, model):
        """Move *model* to the target device and switch it to eval mode."""
        model.to(self.device)
        model.eval()
        return model

    def _fallback_to_mock(self):
        """
        Load the mock model as a fallback, recording and finalizing it.

        Fixes an inconsistency in the original code: fallback models were
        returned without being stored in ``self.model`` or moved to the
        device / set to eval mode like the primary load path.
        """
        self.model = self._load_mock_model()
        return self._finalize(self.model)

    def _load_mock_model(self):
        """Build the mock model sized to the configured class list."""
        return MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))

    def _load_real_model(self, model_name, model_path=None):
        """
        Load a real trained model.

        Args:
            model_name: Model configuration name (key in config.MODEL_CONFIGS).
            model_path: Optional path to a state-dict checkpoint.

        Returns:
            Loaded model (weights applied when ``model_path`` is given).

        Raises:
            ValueError: If the model name or its ``model_type`` is unknown.
        """
        model_config = config.MODEL_CONFIGS.get(model_name)

        if model_config is None:
            raise ValueError(f"Unknown model: {model_name}")

        # TODO: Replace this with your actual model architecture
        # For now, using mock model structure
        if model_config["model_type"] == "cnn":
            model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        elif model_config["model_type"] == "resnet18":
            # TODO: Load ResNet18 transfer learning model
            import torchvision.models as models
            # `weights=None` replaces the deprecated (and later removed)
            # `pretrained=False` argument (torchvision >= 0.13).
            model = models.resnet18(weights=None)
            # Replace the classification head to match our class count.
            model.fc = torch.nn.Linear(model.fc.in_features, len(config.CLASS_NAMES))
        else:
            raise ValueError(f"Unknown model type: {model_config['model_type']}")

        # Load weights if a checkpoint path was provided.
        if model_path:
            print(f"Loading weights from {model_path}")
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (consider weights_only=True on
            # torch >= 1.13; state dicts are plain tensors and load fine).
            model.load_state_dict(torch.load(model_path, map_location=self.device))

        return model

    def load_from_clearml(self, task_id=None, project_name=None, task_name=None):
        """
        Load model from ClearML.

        Args:
            task_id: ClearML task ID (if known).
            project_name: ClearML project name.
            task_name: ClearML task name.

        Returns:
            Loaded model, on ``self.device`` and in eval mode. Falls back to
            the mock model if ClearML is missing or loading fails.
        """
        try:
            from clearml import Task, Model

            if task_id:
                task = Task.get_task(task_id=task_id)
            elif project_name and task_name:
                # Get the latest task with this name
                task = Task.get_task(
                    project_name=project_name,
                    task_name=task_name
                )
            else:
                raise ValueError("Must provide either task_id or (project_name and task_name)")

            # Most recent output model registered on the task, if any.
            model_id = task.models['output'][-1].id if task.models.get('output') else None

            if model_id:
                model_obj = Model(model_id)
                model_path = model_obj.get_local_copy()

                # Load the model
                self.model = self._load_real_model("CNN from Scratch", model_path)
                print(f"Model loaded from ClearML task: {task_id or task_name}")

                # Consistency fix: match load_model() — device + eval mode.
                return self._finalize(self.model)
            else:
                raise ValueError("No output model found in ClearML task")

        except ImportError:
            print("ClearML not installed. Install with: pip install clearml")
            print("Falling back to mock model")
            return self._fallback_to_mock()
        except Exception as e:
            print(f"Error loading from ClearML: {e}")
            print("Falling back to mock model")
            return self._fallback_to_mock()

    def load_from_huggingface(self, model_id):
        """
        Load model from HuggingFace Hub.

        Args:
            model_id: HuggingFace model ID (e.g., "username/model-name").

        Returns:
            Loaded model, on ``self.device`` and in eval mode. Falls back to
            the mock model if huggingface_hub is missing or loading fails.
        """
        try:
            from huggingface_hub import hf_hub_download

            # Download model file
            model_path = hf_hub_download(repo_id=model_id, filename="model.pth")

            # Load the model
            self.model = self._load_real_model("CNN from Scratch", model_path)
            print(f"Model loaded from HuggingFace: {model_id}")

            # Consistency fix: match load_model() — device + eval mode.
            return self._finalize(self.model)

        except ImportError:
            print("huggingface_hub not installed. Install with: pip install huggingface_hub")
            print("Falling back to mock model")
            return self._fallback_to_mock()
        except Exception as e:
            print(f"Error loading from HuggingFace: {e}")
            print("Falling back to mock model")
            return self._fallback_to_mock()


def get_model(use_mock=True, **kwargs):
    """
    Convenience helper: construct a ModelLoader and immediately load a model.

    Args:
        use_mock: Whether to use mock model
        **kwargs: Forwarded to ModelLoader.load_model()

    Returns:
        Tuple of (loaded model, ModelLoader instance)
    """
    model_loader = ModelLoader(use_mock=use_mock)
    loaded = model_loader.load_model(**kwargs)
    return loaded, model_loader


if __name__ == "__main__":
    # Smoke-test the loader end to end.
    print("Testing model loading...")

    # Exercise the mock-model path first.
    print("\n1. Loading mock model:")
    net, ml = get_model(use_mock=True)
    print(f"Model type: {type(net).__name__}")
    print(f"Device: {ml.device}")

    # Forward pass on a random image-shaped tensor to confirm inference runs.
    batch = torch.randn(1, 3, 256, 256).to(ml.device)
    with torch.no_grad():
        out = net(batch)
    print(f"Output shape: {out.shape}")