Conor Brennan (k23064919) commited on
Commit
f622232
·
unverified ·
1 Parent(s): b4cabda

Update model_loader.py

Browse files

remove mock implementation, local and hf model loading

Files changed (1) hide show
  1. ui/model_loader.py +28 -188
ui/model_loader.py CHANGED
@@ -1,207 +1,47 @@
1
- """
2
- Model loading utilities
3
- Handles loading models from different sources: local files, HuggingFace, ClearML
4
- """
5
-
6
  import torch
7
  import sys
8
  from pathlib import Path
 
9
 
10
- # Add parent directory to path to import from models
11
  sys.path.append(str(Path(__file__).parent.parent))
12
 
13
- from models.mock_model import MockPlantDiseaseModel, create_mock_predictions
14
- import config
15
-
16
 
17
  class ModelLoader:
18
- """
19
- Handles loading and managing plant disease models
20
- """
21
-
22
- def __init__(self, use_mock=True):
23
- """
24
- Initialize model loader
25
-
26
- Args:
27
- use_mock: If True, use mock model for development
28
- """
29
- self.use_mock = use_mock
30
- self.model = None
31
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
 
33
- def load_model(self, model_name="CNN from Scratch", model_path=None):
34
- """
35
- Load a model based on configuration
36
-
37
- Args:
38
- model_name: Name of the model configuration
39
- model_path: Optional path to model weights
40
-
41
- Returns:
42
- Loaded model
43
- """
44
- if self.use_mock:
45
- print("Loading mock model for development...")
46
- self.model = self._load_mock_model()
47
- else:
48
- print(f"Loading real model: {model_name}")
49
- self.model = self._load_real_model(model_name, model_path)
50
-
51
- self.model.to(self.device)
52
- self.model.eval()
53
- return self.model
54
-
55
- def _load_mock_model(self):
56
- """Load the mock model"""
57
- model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
58
- return model
59
-
60
- def _load_real_model(self, model_name, model_path=None):
61
- """
62
- Load a real trained model
63
-
64
- Args:
65
- model_name: Model configuration name
66
- model_path: Path to model weights
67
-
68
- Returns:
69
- Loaded model
70
- """
71
- model_config = config.MODEL_CONFIGS.get(model_name)
72
-
73
- if model_config is None:
74
- raise ValueError(f"Unknown model: {model_name}")
75
-
76
- # TODO: Replace this with your actual model architecture
77
- # For now, using mock model structure
78
- if model_config["model_type"] == "cnn":
79
- model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
80
- elif model_config["model_type"] == "resnet18":
81
- # TODO: Load ResNet18 transfer learning model
82
- import torchvision.models as models
83
- model = models.resnet18(pretrained=False)
84
- model.fc = torch.nn.Linear(model.fc.in_features, len(config.CLASS_NAMES))
85
- else:
86
- raise ValueError(f"Unknown model type: {model_config['model_type']}")
87
-
88
- # Load weights if path provided
89
- if model_path:
90
- print(f"Loading weights from {model_path}")
91
- model.load_state_dict(torch.load(model_path, map_location=self.device))
92
-
93
- return model
94
-
95
- def load_from_clearml(self, task_id=None, project_name=None, task_name=None):
96
- """
97
- Load model from ClearML
98
 
99
- Args:
100
- task_id: ClearML task ID (if known)
101
- project_name: ClearML project name
102
- task_name: ClearML task name
103
 
104
- Returns:
105
- Loaded model
106
- """
107
  try:
108
- from clearml import Task, Model
109
 
110
- if task_id:
111
- task = Task.get_task(task_id=task_id)
112
- elif project_name and task_name:
113
- # Get the latest task with this name
114
- task = Task.get_task(
115
- project_name=project_name,
116
- task_name=task_name
117
- )
118
- else:
119
- raise ValueError("Must provide either task_id or (project_name and task_name)")
120
 
121
- # Get the model from the task
122
- model_id = task.models['output'][-1].id if task.models.get('output') else None
123
 
124
- if model_id:
125
- model_obj = Model(model_id)
126
- model_path = model_obj.get_local_copy()
127
-
128
- # Load the model
129
- self.model = self._load_real_model("CNN from Scratch", model_path)
130
- print(f"Model loaded from ClearML task: {task_id or task_name}")
131
-
132
- return self.model
133
- else:
134
- raise ValueError("No output model found in ClearML task")
135
-
136
- except ImportError:
137
- print("ClearML not installed. Install with: pip install clearml")
138
- print("Falling back to mock model")
139
- return self._load_mock_model()
140
  except Exception as e:
141
- print(f"Error loading from ClearML: {e}")
142
- print("Falling back to mock model")
143
- return self._load_mock_model()
144
-
145
- def load_from_huggingface(self, model_id):
146
- """
147
- Load model from HuggingFace Hub
148
-
149
- Args:
150
- model_id: HuggingFace model ID (e.g., "username/model-name")
151
-
152
- Returns:
153
- Loaded model
154
- """
155
  try:
156
- from huggingface_hub import hf_hub_download
157
-
158
- # Download model file
159
- model_path = hf_hub_download(repo_id=model_id, filename="model.pth")
160
-
161
- # Load the model
162
- self.model = self._load_real_model("CNN from Scratch", model_path)
163
- print(f"Model loaded from HuggingFace: {model_id}")
164
-
165
- return self.model
166
-
167
- except ImportError:
168
- print("huggingface_hub not installed. Install with: pip install huggingface_hub")
169
- print("Falling back to mock model")
170
- return self._load_mock_model()
171
  except Exception as e:
172
- print(f"Error loading from HuggingFace: {e}")
173
- print("Falling back to mock model")
174
- return self._load_mock_model()
175
-
176
-
177
- def get_model(use_mock=True, **kwargs):
178
- """
179
- Convenience function to get a loaded model
180
-
181
- Args:
182
- use_mock: Whether to use mock model
183
- **kwargs: Additional arguments for model loading
184
-
185
- Returns:
186
- Loaded model and model loader instance
187
- """
188
- loader = ModelLoader(use_mock=use_mock)
189
- model = loader.load_model(**kwargs)
190
- return model, loader
191
-
192
-
193
- if __name__ == "__main__":
194
- # Test model loading
195
- print("Testing model loading...")
196
-
197
- # Test mock model
198
- print("\n1. Loading mock model:")
199
- model, loader = get_model(use_mock=True)
200
- print(f"Model type: {type(model).__name__}")
201
- print(f"Device: {loader.device}")
202
-
203
- # Test with dummy input
204
- dummy_input = torch.randn(1, 3, 256, 256).to(loader.device)
205
- with torch.no_grad():
206
- output = model(dummy_input)
207
- print(f"Output shape: {output.shape}")
 
 
 
 
 
 
1
  import torch
2
  import sys
3
  from pathlib import Path
4
+ import config
5
 
 
6
  sys.path.append(str(Path(__file__).parent.parent))
7
 
 
 
 
8
 
9
  class ModelLoader:
10
+ def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
11
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ self.modelCache = {}
13
 
14
+ def loadFromClearml(self, modelName):
15
+ modelConfig = config.MODEL_CONFIGS.get(modelName)
16
+
17
+ if not modelConfig or 'clearml_task_id' not in modelConfig:
18
+ raise ValueError(f"ClearML configuration not found for model: {modelName}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ taskID = modelConfig['clearmml_task_id']
21
+ modelType = modelConfig['modelType']
 
 
22
 
 
 
 
23
  try:
24
+ print(f"attemtping to fetch '{modelName}' from clearML task: {taskID}")
25
 
26
+ modelObject = Model(taskID=taskID)
27
+ modelPath = modelObject.get_local_copy()
 
 
 
 
 
 
 
 
28
 
29
+ model = self.loadRealModel(modelName, modelPath, modelType)
 
30
 
31
+ return model
32
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
+ print(f"Error loading from ClearML for {modelName}: {e}")
35
+ raise RuntimeError(f"Failed to load model from ClearML: {e}")
36
+
37
+ def loadModel(self, modelName) :
38
+ if modelName in self.modelCache:
39
+ return self.modelCache[modelName]
40
+
 
 
 
 
 
 
 
41
  try:
42
+ model = self.loadFromClearml(modelName)
43
+ self.modelCache[modelName] = model
44
+ return model
45
+
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
+ raise RuntimeError(f"Could not load model {modelName}. Check ClearML connection.")