k23064919 committed on
Commit
fe2fd36
·
1 Parent(s): 2e1ef59
Files changed (1) hide show
  1. ui/model_loader.py +18 -4
ui/model_loader.py CHANGED
@@ -24,14 +24,27 @@ class ModelLoader:
24
 
25
  try:
26
  print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
27
-
28
  task = Task.get_task(task_id=taskID)
29
 
30
- modelObject = task.models["output"][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- modelPath = modelObject.get_local_copy()
33
  print(f"Weights downloaded to: {modelPath}")
34
 
 
35
  model = modelOne(noOfClasses=39)
36
  stateDict = torch.load(modelPath, map_location=self.device)
37
  model.load_state_dict(stateDict)
@@ -40,10 +53,11 @@ class ModelLoader:
40
  model.eval()
41
 
42
  return model
43
-
44
  except Exception as e:
45
  print(f"Error loading from ClearML for {modelName}: {e}")
46
  raise RuntimeError(f"Failed to load model from ClearML: {e}")
 
47
 
48
  def loadModel(self, modelName):
49
  if modelName in self.modelCache:
 
24
 
25
  try:
26
  print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
27
+
28
  task = Task.get_task(task_id=taskID)
29
 
30
+ # Fetch the artifact 'model_one.pt'
31
+ artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
32
+
33
+ if artifact is None:
34
+ raise RuntimeError(
35
+ f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
36
+ )
37
+
38
+ modelPath = artifact.get_local_copy()
39
+
40
+ if modelPath is None:
41
+ raise RuntimeError(
42
+ f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
43
+ )
44
 
 
45
  print(f"Weights downloaded to: {modelPath}")
46
 
47
+ # Load PyTorch model
48
  model = modelOne(noOfClasses=39)
49
  stateDict = torch.load(modelPath, map_location=self.device)
50
  model.load_state_dict(stateDict)
 
53
  model.eval()
54
 
55
  return model
56
+
57
  except Exception as e:
58
  print(f"Error loading from ClearML for {modelName}: {e}")
59
  raise RuntimeError(f"Failed to load model from ClearML: {e}")
60
+
61
 
62
  def loadModel(self, modelName):
63
  if modelName in self.modelCache: