"""Load trained classification models from ClearML task artifacts."""

import sys
from pathlib import Path

import torch
from clearml import Task

# Extend the import path BEFORE the project-local imports below; appended
# after them (as it originally was) it cannot help resolve them.
# NOTE(review): presumably the grandparent directory is where `config` and
# `models` live -- confirm against the repository layout.
sys.path.append(str(Path(__file__).parent.parent))

import config  # noqa: E402
from models.modelOne import modelOne  # noqa: E402
from models.modelTwo import BetterCNN  # noqa: E402

# Maps the 'class' value of a config.MODEL_CONFIGS entry to the model class
# that is instantiated for it.
MODEL_CLASSES = {
    "modelOne": modelOne,
    "betterCNN": BetterCNN,
}

# Name of the ClearML task artifact that holds the saved state dict.
MODEL_ARTIFACT_NAME = 'best_model'

# Value passed as `noOfClasses` to every model constructor.
NUM_CLASSES = 39


class ModelLoader:
    """Fetch model weights from ClearML and cache the ready-to-use models.

    Models are placed on CUDA when available (CPU otherwise), switched to
    eval mode, and cached per model name for the lifetime of the loader.
    """

    def __init__(self):
        """Select the target device and start with an empty model cache."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # modelName -> loaded model, filled by loadModel().
        self.modelCache = {}

    def loadFromClearml(self, modelName):
        """Download the weights for ``modelName`` from ClearML and build the model.

        Args:
            modelName: Key into ``config.MODEL_CONFIGS``. The entry must
                provide ``'clearml_task_id'`` and ``'class'``.

        Returns:
            The model with weights loaded, moved to ``self.device`` and set
            to eval mode.

        Raises:
            ValueError: If ``modelName`` has no (non-empty) configuration.
            RuntimeError: If the model class is unknown, the artifact is
                missing or cannot be downloaded, or loading the weights
                fails. The original exception is attached as ``__cause__``.
        """
        modelConfig = config.MODEL_CONFIGS.get(modelName)
        if not modelConfig:
            raise ValueError(f"ClearML configuration not found for model: {modelName}")

        taskID = modelConfig['clearml_task_id']
        className = modelConfig['class']

        try:
            # Resolve the class first so a misconfigured name fails fast with
            # a readable message instead of a bare KeyError after the download.
            ModelClass = MODEL_CLASSES.get(className)
            if ModelClass is None:
                raise RuntimeError(
                    f"Unknown model class '{className}' for model '{modelName}'; "
                    f"expected one of {sorted(MODEL_CLASSES)}"
                )

            print(f"Attempting to fetch '{modelName}' from ClearML task: {taskID}")
            task = Task.get_task(task_id=taskID)
            print("Available artifacts:", task.artifacts.keys())

            artifact = task.artifacts.get(MODEL_ARTIFACT_NAME)
            if artifact is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' not found in ClearML task {taskID}"
                )

            modelPath = artifact.get_local_copy()
            if modelPath is None:
                raise RuntimeError(
                    f"Artifact '{MODEL_ARTIFACT_NAME}' could not be downloaded (returned None)"
                )
            print(f"Weights downloaded to: {modelPath}")

            # Load correct model class
            model = ModelClass(noOfClasses=NUM_CLASSES)

            # Load weights
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only safe while the ClearML artifact is trusted;
            # consider weights_only=True if the installed torch supports it.
            stateDict = torch.load(modelPath, map_location=self.device)
            model.load_state_dict(stateDict)
            model.to(self.device)
            model.eval()
            return model
        except Exception as e:
            print(f"Error loading from ClearML for {modelName}: {e}")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model from ClearML: {e}") from e

    def loadModel(self, modelName):
        """Return the model for ``modelName``, loading it on first use.

        Args:
            modelName: Key into ``config.MODEL_CONFIGS``.

        Returns:
            The cached model if present, otherwise a freshly loaded one
            (which is then cached).

        Raises:
            RuntimeError: If loading fails for any reason; the underlying
                error is available as ``__cause__``.
        """
        if modelName in self.modelCache:
            return self.modelCache[modelName]

        try:
            model = self.loadFromClearml(modelName)
        except Exception as e:
            # Keep the generic message callers may match on, but preserve the
            # real reason via exception chaining.
            raise RuntimeError(
                f"Could not load model {modelName}. Check ClearML connection."
            ) from e

        self.modelCache[modelName] = model
        return model