k23064919 commited on
Commit
b727fe2
·
1 Parent(s): 58a6269
Files changed (2) hide show
  1. ui/config.py +2 -2
  2. ui/model_loader.py +12 -6
ui/config.py CHANGED
@@ -4,12 +4,12 @@ CLEARML_TASK_NAME_DEFAULT = "CNN Training (Latest)"
4
  MODEL_CONFIGS = {
5
  "Shallow CNN": {
6
  "description": "modelOne trained on 20 epochs",
7
- "model_type" : "cnn",
8
  "clearml_task_id": "dca82d7c2f404c249f2e5325aaf77207"
9
  },
10
  "Deep CNN": {
11
  "description" : "modleTwo trained on 30 epochs",
12
- "model_type": "cnn",
13
  "clearml_task_id": "c79a6939b46a4882a7fdaee117b1f32e"
14
  }
15
  }
 
4
  MODEL_CONFIGS = {
5
  "Shallow CNN": {
6
  "description": "modelOne trained on 20 epochs",
7
+ "class" : "modelOne",
8
  "clearml_task_id": "dca82d7c2f404c249f2e5325aaf77207"
9
  },
10
  "Deep CNN": {
11
  "description" : "modleTwo trained on 30 epochs",
12
+ "class": "betterCNN",
13
  "clearml_task_id": "c79a6939b46a4882a7fdaee117b1f32e"
14
  }
15
  }
ui/model_loader.py CHANGED
@@ -2,12 +2,17 @@ import torch
2
  import sys
3
  from pathlib import Path
4
  import config
5
- import utils
6
  from clearml import Task
7
  from models.modelOne import modelOne
 
8
 
9
  sys.path.append(str(Path(__file__).parent.parent))
10
 
 
 
 
 
 
11
  MODEL_ARTIFACT_NAME = 'best_model'
12
 
13
  class ModelLoader:
@@ -22,6 +27,7 @@ class ModelLoader:
22
  raise ValueError(f"ClearML configuration not found for model: {modelName}")
23
 
24
  taskID = modelConfig['clearml_task_id']
 
25
 
26
  try:
27
  print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
@@ -29,16 +35,13 @@ class ModelLoader:
29
  task = Task.get_task(task_id=taskID)
30
  print("Available artifacts:", task.artifacts.keys())
31
 
32
- # Fetch the artifact 'model_one.pt'
33
  artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
34
-
35
  if artifact is None:
36
  raise RuntimeError(
37
  f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
38
  )
39
 
40
  modelPath = artifact.get_local_copy()
41
-
42
  if modelPath is None:
43
  raise RuntimeError(
44
  f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
@@ -46,8 +49,11 @@ class ModelLoader:
46
 
47
  print(f"Weights downloaded to: {modelPath}")
48
 
49
- # Load PyTorch model
50
- model = modelOne(noOfClasses=39)
 
 
 
51
  stateDict = torch.load(modelPath, map_location=self.device)
52
  model.load_state_dict(stateDict)
53
 
 
2
  import sys
3
  from pathlib import Path
4
  import config
 
5
  from clearml import Task
6
  from models.modelOne import modelOne
7
+ from models.modelTwo import BetterCNN
8
 
9
  sys.path.append(str(Path(__file__).parent.parent))
10
 
11
+ MODEL_CLASSES = {
12
+ "modelOne": modelOne,
13
+ "betterCNN": BetterCNN
14
+ }
15
+
16
  MODEL_ARTIFACT_NAME = 'best_model'
17
 
18
  class ModelLoader:
 
27
  raise ValueError(f"ClearML configuration not found for model: {modelName}")
28
 
29
  taskID = modelConfig['clearml_task_id']
30
+ className = modelConfig['class']
31
 
32
  try:
33
  print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
 
35
  task = Task.get_task(task_id=taskID)
36
  print("Available artifacts:", task.artifacts.keys())
37
 
 
38
  artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
 
39
  if artifact is None:
40
  raise RuntimeError(
41
  f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
42
  )
43
 
44
  modelPath = artifact.get_local_copy()
 
45
  if modelPath is None:
46
  raise RuntimeError(
47
  f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
 
49
 
50
  print(f"Weights downloaded to: {modelPath}")
51
 
52
+ # Load correct model class
53
+ ModelClass = MODEL_CLASSES[className]
54
+ model = ModelClass(noOfClasses=39)
55
+
56
+ # Load weights
57
  stateDict = torch.load(modelPath, map_location=self.device)
58
  model.load_state_dict(stateDict)
59