Spaces:
Sleeping
Sleeping
quickfix
Browse files- ui/model_loader.py +18 -4
ui/model_loader.py
CHANGED
|
@@ -24,14 +24,27 @@ class ModelLoader:
|
|
| 24 |
|
| 25 |
try:
|
| 26 |
print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
|
| 27 |
-
|
| 28 |
task = Task.get_task(task_id=taskID)
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
modelPath = modelObject.get_local_copy()
|
| 33 |
print(f"Weights downloaded to: {modelPath}")
|
| 34 |
|
|
|
|
| 35 |
model = modelOne(noOfClasses=39)
|
| 36 |
stateDict = torch.load(modelPath, map_location=self.device)
|
| 37 |
model.load_state_dict(stateDict)
|
|
@@ -40,10 +53,11 @@ class ModelLoader:
|
|
| 40 |
model.eval()
|
| 41 |
|
| 42 |
return model
|
| 43 |
-
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error loading from ClearML for {modelName}: {e}")
|
| 46 |
raise RuntimeError(f"Failed to load model from ClearML: {e}")
|
|
|
|
| 47 |
|
| 48 |
def loadModel(self, modelName):
|
| 49 |
if modelName in self.modelCache:
|
|
|
|
| 24 |
|
| 25 |
try:
|
| 26 |
print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
|
| 27 |
+
|
| 28 |
task = Task.get_task(task_id=taskID)
|
| 29 |
|
| 30 |
+
# Fetch the artifact 'model_one.pt'
|
| 31 |
+
artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
|
| 32 |
+
|
| 33 |
+
if artifact is None:
|
| 34 |
+
raise RuntimeError(
|
| 35 |
+
f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
modelPath = artifact.get_local_copy()
|
| 39 |
+
|
| 40 |
+
if modelPath is None:
|
| 41 |
+
raise RuntimeError(
|
| 42 |
+
f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
|
| 43 |
+
)
|
| 44 |
|
|
|
|
| 45 |
print(f"Weights downloaded to: {modelPath}")
|
| 46 |
|
| 47 |
+
# Load PyTorch model
|
| 48 |
model = modelOne(noOfClasses=39)
|
| 49 |
stateDict = torch.load(modelPath, map_location=self.device)
|
| 50 |
model.load_state_dict(stateDict)
|
|
|
|
| 53 |
model.eval()
|
| 54 |
|
| 55 |
return model
|
| 56 |
+
|
| 57 |
except Exception as e:
|
| 58 |
print(f"Error loading from ClearML for {modelName}: {e}")
|
| 59 |
raise RuntimeError(f"Failed to load model from ClearML: {e}")
|
| 60 |
+
|
| 61 |
|
| 62 |
def loadModel(self, modelName):
|
| 63 |
if modelName in self.modelCache:
|