"""
Model loading utilities
Handles loading models from different sources: local files, HuggingFace, ClearML
"""
import torch
import sys
from pathlib import Path

# Add parent directory to path to import from models
sys.path.append(str(Path(__file__).parent.parent))

from models.mock_model import MockPlantDiseaseModel, create_mock_predictions
import config


class ModelLoader:
    """
    Handles loading and managing plant disease models.

    Every public loading method (``load_model``, ``load_from_clearml``,
    ``load_from_huggingface``) stores the resulting model in ``self.model``,
    moves it to ``self.device`` and switches it to evaluation mode, so the
    returned model is ready for inference whatever its source. This also
    holds for the mock model used as a fallback.
    """

    def __init__(self, use_mock=True):
        """
        Initialize model loader

        Args:
            use_mock: If True, ``load_model`` uses the mock model for development
        """
        self.use_mock = use_mock
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name="CNN from Scratch", model_path=None):
        """
        Load a model based on configuration

        Args:
            model_name: Name of the model configuration (key of config.MODEL_CONFIGS)
            model_path: Optional path to model weights

        Returns:
            Loaded model, on ``self.device`` and in eval mode

        Raises:
            ValueError: if ``model_name`` or its model type is unknown
                (only when not using the mock model)
        """
        if self.use_mock:
            print("Loading mock model for development...")
            model = self._load_mock_model()
        else:
            print(f"Loading real model: {model_name}")
            model = self._load_real_model(model_name, model_path)
        return self._activate(model)

    def _activate(self, model):
        """Move *model* to the loader's device, set eval mode, store and return it."""
        model.to(self.device)
        model.eval()
        self.model = model
        return model

    def _load_mock_model(self):
        """Build a fresh mock model sized to config.CLASS_NAMES (not moved/activated)."""
        model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        return model

    def _load_real_model(self, model_name, model_path=None):
        """
        Build a real model architecture and optionally load its weights.

        The model is returned as built; callers are responsible for device
        placement and eval mode (see ``_activate``).

        Args:
            model_name: Model configuration name (key of config.MODEL_CONFIGS)
            model_path: Path to a state_dict file; if falsy, no weights are loaded

        Returns:
            Constructed model

        Raises:
            ValueError: if the model name or its ``model_type`` is unknown
        """
        model_config = config.MODEL_CONFIGS.get(model_name)
        if model_config is None:
            raise ValueError(f"Unknown model: {model_name}")

        # TODO: Replace this with your actual model architecture
        # For now, using mock model structure
        if model_config["model_type"] == "cnn":
            model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        elif model_config["model_type"] == "resnet18":
            # TODO: Load ResNet18 transfer learning model
            import torchvision.models as models
            model = models.resnet18(pretrained=False)
            # Replace the classifier head so the output size matches our classes
            model.fc = torch.nn.Linear(model.fc.in_features, len(config.CLASS_NAMES))
        else:
            raise ValueError(f"Unknown model type: {model_config['model_type']}")

        # Load weights if path provided
        if model_path:
            print(f"Loading weights from {model_path}")
            # NOTE(security): torch.load unpickles the file, and files fetched
            # from ClearML/HuggingFace are external input. Consider passing
            # weights_only=True once the minimum supported torch allows it.
            model.load_state_dict(torch.load(model_path, map_location=self.device))

        return model

    def load_from_clearml(self, task_id=None, project_name=None, task_name=None,
                          model_name="CNN from Scratch"):
        """
        Load model from ClearML

        Uses the last output model registered on the task. On any failure,
        including missing arguments or clearml not being installed, a message
        is printed and the mock model is returned instead (best-effort).

        Args:
            task_id: ClearML task ID (if known)
            project_name: ClearML project name
            task_name: ClearML task name
            model_name: Architecture to build for the downloaded weights
                (key of config.MODEL_CONFIGS)

        Returns:
            Loaded model (or mock fallback), on ``self.device`` and in eval mode
        """
        try:
            from clearml import Task, Model
        except ImportError:
            print("ClearML not installed. Install with: pip install clearml")
            print("Falling back to mock model")
            return self._activate(self._load_mock_model())

        try:
            if task_id:
                task = Task.get_task(task_id=task_id)
            elif project_name and task_name:
                # Get the latest task with this name
                task = Task.get_task(
                    project_name=project_name,
                    task_name=task_name
                )
            else:
                raise ValueError("Must provide either task_id or (project_name and task_name)")

            # Get the model from the task
            output_models = task.models.get('output')
            model_id = output_models[-1].id if output_models else None
            if not model_id:
                raise ValueError("No output model found in ClearML task")

            model_path = Model(model_id).get_local_copy()
            # Without a local copy _load_real_model would silently return an
            # untrained network, so treat a failed download as an error.
            if not model_path:
                raise ValueError(f"Could not download weights for ClearML model {model_id}")

            model = self._load_real_model(model_name, model_path)
        except Exception as e:
            print(f"Error loading from ClearML: {e}")
            print("Falling back to mock model")
            return self._activate(self._load_mock_model())

        print(f"Model loaded from ClearML task: {task_id or task_name}")
        return self._activate(model)

    def load_from_huggingface(self, model_id, model_name="CNN from Scratch",
                              filename="model.pth"):
        """
        Load model from HuggingFace Hub

        On any failure, including huggingface_hub not being installed, a
        message is printed and the mock model is returned instead
        (best-effort).

        Args:
            model_id: HuggingFace model ID (e.g., "username/model-name")
            model_name: Architecture to build for the downloaded weights
                (key of config.MODEL_CONFIGS)
            filename: Weights file inside the repository

        Returns:
            Loaded model (or mock fallback), on ``self.device`` and in eval mode
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            print("huggingface_hub not installed. Install with: pip install huggingface_hub")
            print("Falling back to mock model")
            return self._activate(self._load_mock_model())

        try:
            # Download model file
            model_path = hf_hub_download(repo_id=model_id, filename=filename)
            model = self._load_real_model(model_name, model_path)
        except Exception as e:
            print(f"Error loading from HuggingFace: {e}")
            print("Falling back to mock model")
            return self._activate(self._load_mock_model())

        print(f"Model loaded from HuggingFace: {model_id}")
        return self._activate(model)


def get_model(use_mock=True, **kwargs):
    """
    Convenience function to get a loaded model

    Args:
        use_mock: Whether to use mock model
        **kwargs: Forwarded to ``ModelLoader.load_model`` (model_name, model_path)

    Returns:
        Tuple of (loaded model, ModelLoader instance)
    """
    loader = ModelLoader(use_mock=use_mock)
    model = loader.load_model(**kwargs)
    return model, loader


if __name__ == "__main__":
    # Test model loading
    print("Testing model loading...")

    # Test mock model
    print("\n1. Loading mock model:")
    model, loader = get_model(use_mock=True)
    print(f"Model type: {type(model).__name__}")
    print(f"Device: {loader.device}")

    # Test with dummy input
    dummy_input = torch.randn(1, 3, 256, 256).to(loader.device)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Output shape: {output.shape}")