Spaces:
Sleeping
Sleeping
File size: 6,572 Bytes
505fc99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
"""
Model loading utilities
Handles loading models from different sources: local files, HuggingFace, ClearML
"""
import torch
import sys
from pathlib import Path
# Add parent directory to path to import from models
sys.path.append(str(Path(__file__).parent.parent))
from models.mock_model import MockPlantDiseaseModel, create_mock_predictions
import config
class ModelLoader:
    """
    Handles loading and managing plant disease models.

    Every public loading method stores the model it returns in ``self.model``,
    moved to ``self.device`` and switched to eval mode. This also holds when a
    remote load fails and the loader falls back to the mock model.
    """

    def __init__(self, use_mock=True):
        """
        Initialize model loader.

        Args:
            use_mock: If True, ``load_model`` returns the mock model for development.
        """
        self.use_mock = use_mock
        self.model = None
        # Use the GPU when torch can see one, otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_name="CNN from Scratch", model_path=None):
        """
        Load a model based on configuration.

        Args:
            model_name: Key into ``config.MODEL_CONFIGS`` (ignored in mock mode).
            model_path: Optional path to a state_dict to load into the model.

        Returns:
            The loaded model, on ``self.device`` and in eval mode.

        Raises:
            ValueError: If ``model_name`` or its ``model_type`` is unknown.
        """
        if self.use_mock:
            print("Loading mock model for development...")
            model = self._load_mock_model()
        else:
            print(f"Loading real model: {model_name}")
            model = self._load_real_model(model_name, model_path)
        return self._activate(model)

    def _activate(self, model):
        """Move *model* to ``self.device``, set eval mode, store it as ``self.model`` and return it."""
        model.to(self.device)
        model.eval()
        self.model = model
        return model

    def _fallback_to_mock(self):
        """Announce the fallback, then load, activate and return the mock model."""
        print("Falling back to mock model")
        return self._activate(self._load_mock_model())

    def _load_mock_model(self):
        """Load the mock model"""
        model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        return model

    def _load_real_model(self, model_name, model_path=None):
        """
        Build the architecture named by *model_name* and optionally load weights.

        Args:
            model_name: Key into ``config.MODEL_CONFIGS``.
            model_path: Path to a state_dict file; when falsy, no weights are loaded.

        Returns:
            The constructed model (not yet moved to a device or put in eval mode).

        Raises:
            ValueError: If the model name or its ``model_type`` is unknown.
        """
        model_config = config.MODEL_CONFIGS.get(model_name)
        if model_config is None:
            raise ValueError(f"Unknown model: {model_name}")
        model_type = model_config["model_type"]
        num_classes = len(config.CLASS_NAMES)

        # TODO: Replace this with your actual model architecture
        # For now, using mock model structure
        if model_type == "cnn":
            model = MockPlantDiseaseModel(num_classes=num_classes)
        elif model_type == "resnet18":
            # TODO: Load ResNet18 transfer learning model
            import torchvision.models as models
            # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
            # `weights=None`; kept here for compatibility with older torchvision releases.
            model = models.resnet18(pretrained=False)
            # Replace the classifier head so its output size matches config.CLASS_NAMES.
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Load weights if path provided
        if model_path:
            print(f"Loading weights from {model_path}")
            # SECURITY: torch.load unpickles the file, and weights fetched from ClearML or the
            # HuggingFace Hub are external input. Consider weights_only=True (torch >= 1.13).
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
        return model

    def load_from_clearml(self, task_id=None, project_name=None, task_name=None,
                          model_name="CNN from Scratch"):
        """
        Load the last output model of a ClearML task.

        On any failure (clearml missing, bad arguments, no output model, download
        or load error) a message is printed and the mock model is returned instead.

        Args:
            task_id: ClearML task ID (if known).
            project_name: ClearML project name (used with ``task_name``).
            task_name: ClearML task name (used with ``project_name``).
            model_name: Architecture key in ``config.MODEL_CONFIGS`` the weights belong to.

        Returns:
            The loaded model (or the mock fallback), on ``self.device`` and in eval mode.
        """
        # Keep the import separate so that only a missing clearml package is reported as such.
        try:
            from clearml import Task, Model
        except ImportError:
            print("ClearML not installed. Install with: pip install clearml")
            return self._fallback_to_mock()

        try:
            if task_id:
                task = Task.get_task(task_id=task_id)
            elif project_name and task_name:
                task = Task.get_task(project_name=project_name, task_name=task_name)
            else:
                raise ValueError("Must provide either task_id or (project_name and task_name)")

            # Use the most recently registered output model of the task.
            output_models = task.models.get('output')
            model_id = output_models[-1].id if output_models else None
            if not model_id:
                raise ValueError("No output model found in ClearML task")

            model_path = Model(model_id).get_local_copy()
            # A falsy path would make _load_real_model skip the weights and return an
            # untrained network, so treat it as a failed download.
            if not model_path:
                raise ValueError(f"Could not download ClearML model {model_id}")

            model = self._activate(self._load_real_model(model_name, model_path))
            print(f"Model loaded from ClearML task: {task_id or task_name}")
            return model
        except Exception as e:
            # Deliberate best-effort: any failure degrades to the mock model.
            print(f"Error loading from ClearML: {e}")
            return self._fallback_to_mock()

    def load_from_huggingface(self, model_id, model_name="CNN from Scratch", filename="model.pth"):
        """
        Load model weights from the HuggingFace Hub.

        On any failure a message is printed and the mock model is returned instead.

        Args:
            model_id: HuggingFace repo ID (e.g. "username/model-name").
            model_name: Architecture key in ``config.MODEL_CONFIGS`` the weights belong to.
            filename: Name of the weights file inside the repo.

        Returns:
            The loaded model (or the mock fallback), on ``self.device`` and in eval mode.
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            print("huggingface_hub not installed. Install with: pip install huggingface_hub")
            return self._fallback_to_mock()

        try:
            model_path = hf_hub_download(repo_id=model_id, filename=filename)
            model = self._activate(self._load_real_model(model_name, model_path))
            print(f"Model loaded from HuggingFace: {model_id}")
            return model
        except Exception as e:
            # Deliberate best-effort: any failure degrades to the mock model.
            print(f"Error loading from HuggingFace: {e}")
            return self._fallback_to_mock()
def get_model(use_mock=True, **kwargs):
    """
    Create a ModelLoader and load a model with it in a single call.

    Args:
        use_mock: Passed to ModelLoader; True selects the mock model.
        **kwargs: Forwarded unchanged to ModelLoader.load_model
            (e.g. model_name, model_path).

    Returns:
        A (model, loader) tuple: the loaded model and the loader that built it.
    """
    model_loader = ModelLoader(use_mock=use_mock)
    return model_loader.load_model(**kwargs), model_loader
if __name__ == "__main__":
    # Smoke test: load the mock model and run one random batch through it.
    print("Testing model loading...")

    print("\n1. Loading mock model:")
    mock_model, mock_loader = get_model(use_mock=True)
    print(f"Model type: {type(mock_model).__name__}")
    print(f"Device: {mock_loader.device}")

    # Single random input of shape (1, 3, 256, 256), placed on the loader's device.
    batch = torch.randn(1, 3, 256, 256).to(mock_loader.device)
    with torch.no_grad():
        result = mock_model(batch)
    print(f"Output shape: {result.shape}")
|