k23064919 commited on
Commit
5dc8013
·
1 Parent(s): fec2e27

add model loading to modelLoader

Browse files
Files changed (2) hide show
  1. ui/config.py +0 -5
  2. ui/model_loader.py +9 -1
ui/config.py CHANGED
@@ -6,10 +6,5 @@ MODEL_CONFIGS = {
6
  "description": "Custom CNN model trained from scratch",
7
  "model_type": "cnn",
8
  "clearml_task_id": "fe14662da63d45bf9208fdf9856d2fcc"
9
- },
10
- "Transfer Learning (ResNet18)": {
11
- "description": "Fine-tuned ResNet18 model",
12
- "model_type": "resnet18",
13
- "clearml_task_id": "SET_ME_TO_YOUR_RESNET_TASK_ID"
14
  }
15
  }
 
6
  "description": "Custom CNN model trained from scratch",
7
  "model_type": "cnn",
8
  "clearml_task_id": "fe14662da63d45bf9208fdf9856d2fcc"
 
 
 
 
 
9
  }
10
  }
ui/model_loader.py CHANGED
@@ -3,6 +3,7 @@ import sys
3
  from pathlib import Path
4
  import config
5
  from clearml import Model
 
6
 
7
  sys.path.append(str(Path(__file__).parent.parent))
8
 
@@ -26,8 +27,15 @@ class ModelLoader:
26
 
27
  modelObject = Model(model_id=taskID)
28
  modelPath = modelObject.get_local_copy()
 
29
 
30
- model = self.loadRealModel(modelName, modelPath, modelType)
 
 
 
 
 
 
31
 
32
  return model
33
 
 
3
  from pathlib import Path
4
  import config
5
  from clearml import Model
6
+ from ..models import modelOne
7
 
8
  sys.path.append(str(Path(__file__).parent.parent))
9
 
 
27
 
28
  modelObject = Model(model_id=taskID)
29
  modelPath = modelObject.get_local_copy()
30
+ print(f"Weights downloaded to local path: {modelPath}")
31
 
32
+ model = self.modelOne()
33
+
34
+ stateDict = torch.load(modelPath, map_location=self.device)
35
+ modelObject.load_state_dict(stateDict)
36
+
37
+ model.to(self.device)
38
+ model.eval()
39
 
40
  return model
41