Spaces:
Sleeping
Sleeping
"""
Model loading utilities.

Handles loading models from different sources: local files, HuggingFace, ClearML.
"""
| import torch | |
| import sys | |
| from pathlib import Path | |
| # Add parent directory to path to import from models | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from models.mock_model import MockPlantDiseaseModel, create_mock_predictions | |
| import config | |
class ModelLoader:
    """
    Handles loading and managing plant disease models.

    Supports a mock model for development plus real models loaded from
    local weight files, ClearML tasks, or the HuggingFace Hub. Every
    model returned by the public loaders is moved to ``self.device``
    and switched to eval mode.
    """

    def __init__(self, use_mock=True):
        """
        Initialize model loader.

        Args:
            use_mock: If True, use mock model for development
        """
        self.use_mock = use_mock
        self.model = None
        # Prefer GPU when available; all loaded models are moved here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _finalize(self, model):
        """Move a model to the target device and put it in eval mode."""
        model.to(self.device)
        model.eval()
        return model

    def load_model(self, model_name="CNN from Scratch", model_path=None):
        """
        Load a model based on configuration.

        Args:
            model_name: Name of the model configuration
            model_path: Optional path to model weights

        Returns:
            Loaded model (on ``self.device``, in eval mode)
        """
        if self.use_mock:
            print("Loading mock model for development...")
            self.model = self._load_mock_model()
        else:
            print(f"Loading real model: {model_name}")
            self.model = self._load_real_model(model_name, model_path)
        return self._finalize(self.model)

    def _load_mock_model(self):
        """Build the development mock model sized to the configured class list."""
        return MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))

    def _load_real_model(self, model_name, model_path=None):
        """
        Load a real trained model.

        Args:
            model_name: Model configuration name (key into config.MODEL_CONFIGS)
            model_path: Path to model weights

        Returns:
            Loaded model (not yet moved to device; callers finalize)

        Raises:
            ValueError: If model_name or its model_type is unknown
        """
        model_config = config.MODEL_CONFIGS.get(model_name)
        if model_config is None:
            raise ValueError(f"Unknown model: {model_name}")

        # TODO: Replace this with your actual model architecture
        # For now, using mock model structure
        if model_config["model_type"] == "cnn":
            model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        elif model_config["model_type"] == "resnet18":
            # ResNet18 backbone with a fresh classification head.
            # `weights=None` replaces the deprecated/removed `pretrained=False`.
            import torchvision.models as models
            model = models.resnet18(weights=None)
            model.fc = torch.nn.Linear(model.fc.in_features, len(config.CLASS_NAMES))
        else:
            raise ValueError(f"Unknown model type: {model_config['model_type']}")

        # Load weights if path provided
        if model_path:
            print(f"Loading weights from {model_path}")
            # NOTE(review): downloaded checkpoints are untrusted pickle data;
            # consider torch.load(..., weights_only=True) (torch>=1.13) — confirm
            # the installed torch version supports it before enabling.
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        return model

    def load_from_clearml(self, task_id=None, project_name=None, task_name=None):
        """
        Load model from ClearML.

        Args:
            task_id: ClearML task ID (if known)
            project_name: ClearML project name
            task_name: ClearML task name

        Returns:
            Loaded model (on ``self.device``, in eval mode); falls back to
            the mock model if ClearML is unavailable or loading fails.
        """
        try:
            from clearml import Task, Model

            if task_id:
                task = Task.get_task(task_id=task_id)
            elif project_name and task_name:
                # Get the latest task with this name
                task = Task.get_task(
                    project_name=project_name,
                    task_name=task_name
                )
            else:
                raise ValueError("Must provide either task_id or (project_name and task_name)")

            # Most recent output model registered on the task, if any.
            model_id = task.models['output'][-1].id if task.models.get('output') else None
            if model_id:
                model_obj = Model(model_id)
                model_path = model_obj.get_local_copy()
                self.model = self._load_real_model("CNN from Scratch", model_path)
                print(f"Model loaded from ClearML task: {task_id or task_name}")
                # Fix: previously the remote-loaded model was returned without
                # being moved to the device or set to eval mode.
                return self._finalize(self.model)
            else:
                raise ValueError("No output model found in ClearML task")
        except ImportError:
            print("ClearML not installed. Install with: pip install clearml")
            print("Falling back to mock model")
            self.model = self._load_mock_model()
            return self._finalize(self.model)
        except Exception as e:
            print(f"Error loading from ClearML: {e}")
            print("Falling back to mock model")
            self.model = self._load_mock_model()
            return self._finalize(self.model)

    def load_from_huggingface(self, model_id):
        """
        Load model from HuggingFace Hub.

        Args:
            model_id: HuggingFace model ID (e.g., "username/model-name")

        Returns:
            Loaded model (on ``self.device``, in eval mode); falls back to
            the mock model if huggingface_hub is unavailable or loading fails.
        """
        try:
            from huggingface_hub import hf_hub_download

            # Download model file
            model_path = hf_hub_download(repo_id=model_id, filename="model.pth")
            # Load the model
            self.model = self._load_real_model("CNN from Scratch", model_path)
            print(f"Model loaded from HuggingFace: {model_id}")
            # Fix: previously the remote-loaded model was returned without
            # being moved to the device or set to eval mode.
            return self._finalize(self.model)
        except ImportError:
            print("huggingface_hub not installed. Install with: pip install huggingface_hub")
            print("Falling back to mock model")
            self.model = self._load_mock_model()
            return self._finalize(self.model)
        except Exception as e:
            print(f"Error loading from HuggingFace: {e}")
            print("Falling back to mock model")
            self.model = self._load_mock_model()
            return self._finalize(self.model)
def get_model(use_mock=True, **kwargs):
    """
    Convenience function to get a loaded model.

    Args:
        use_mock: Whether to use mock model
        **kwargs: Additional arguments forwarded to ModelLoader.load_model

    Returns:
        Tuple of (loaded model, ModelLoader instance)
    """
    model_loader = ModelLoader(use_mock=use_mock)
    return model_loader.load_model(**kwargs), model_loader
if __name__ == "__main__":
    # Smoke-test: load the mock model and run one dummy batch through it.
    print("Testing model loading...")

    print("\n1. Loading mock model:")
    model, loader = get_model(use_mock=True)
    print(f"Model type: {type(model).__name__}")
    print(f"Device: {loader.device}")

    # Single RGB image batch at the expected 256x256 resolution.
    sample = torch.randn(1, 3, 256, 256).to(loader.device)
    with torch.no_grad():
        logits = model(sample)
    print(f"Output shape: {logits.shape}")