smallGroupProject / ui /model_loader.py
Mert Yerlikaya
Add feature-rich Gradio UI with mock model
505fc99
raw
history blame
6.57 kB
"""
Model loading utilities
Handles loading models from different sources: local files, HuggingFace, ClearML
"""
import torch
import sys
from pathlib import Path
# Add parent directory to path to import from models
sys.path.append(str(Path(__file__).parent.parent))
from models.mock_model import MockPlantDiseaseModel, create_mock_predictions
import config
class ModelLoader:
    """
    Handles loading and managing plant disease models.

    The loader remembers the most recently loaded model in ``self.model`` and the
    target device in ``self.device``. Every public ``load_*`` method returns a
    model that has been moved to ``self.device`` and switched to eval mode. This
    holds whether the model came from a local file, ClearML, HuggingFace, or the
    mock fallback.
    """

    def __init__(self, use_mock=True):
        """
        Initialize model loader

        Args:
            use_mock: If True, ``load_model`` builds the mock model instead of a
                real architecture (useful for UI development).
        """
        self.use_mock = use_mock
        self.model = None
        # Use the GPU when CUDA is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _prepare_model(self, model):
        """
        Move ``model`` to ``self.device``, switch it to eval mode, and remember it.

        Every loading path funnels through here, so all models come back in the
        same state and ``self.model`` always reflects what was last returned.

        Args:
            model: A ``torch.nn.Module``.

        Returns:
            The same module, after ``.to(self.device)`` and ``.eval()``.
        """
        model.to(self.device)
        # Inference mode: disables dropout and uses stored batch-norm statistics.
        model.eval()
        self.model = model
        return model

    def _fallback_to_mock(self):
        """Announce the fallback and return the prepared mock model."""
        print("Falling back to mock model")
        return self._prepare_model(self._load_mock_model())

    def load_model(self, model_name="CNN from Scratch", model_path=None):
        """
        Load a model based on configuration

        Args:
            model_name: Name of the model configuration (key of ``config.MODEL_CONFIGS``);
                ignored when ``use_mock`` is True.
            model_path: Optional path to model weights

        Returns:
            Loaded model, on ``self.device`` and in eval mode.

        Raises:
            ValueError: If ``model_name`` or its ``model_type`` is unknown.
        """
        if self.use_mock:
            print("Loading mock model for development...")
            model = self._load_mock_model()
        else:
            print(f"Loading real model: {model_name}")
            model = self._load_real_model(model_name, model_path)
        return self._prepare_model(model)

    def _load_mock_model(self):
        """Load the mock model"""
        model = MockPlantDiseaseModel(num_classes=len(config.CLASS_NAMES))
        return model

    def _load_real_model(self, model_name, model_path=None):
        """
        Build a real model architecture and optionally load trained weights.

        Args:
            model_name: Model configuration name (key of ``config.MODEL_CONFIGS``).
            model_path: Path to a saved state_dict; if falsy, the model keeps its
                freshly initialized weights.

        Returns:
            The constructed ``torch.nn.Module`` (not yet moved or set to eval mode).

        Raises:
            ValueError: If ``model_name`` or its ``model_type`` is unknown.
        """
        model_config = config.MODEL_CONFIGS.get(model_name)
        if model_config is None:
            raise ValueError(f"Unknown model: {model_name}")

        model_type = model_config["model_type"]
        num_classes = len(config.CLASS_NAMES)

        # TODO: Replace this with your actual model architecture
        # For now, using mock model structure
        if model_type == "cnn":
            model = MockPlantDiseaseModel(num_classes=num_classes)
        elif model_type == "resnet18":
            # Imported lazily so torchvision is only required for this model type.
            import torchvision.models as models
            # weights=None -> randomly initialized backbone (replaces the deprecated
            # ``pretrained=False``); trained weights are loaded from model_path below.
            model = models.resnet18(weights=None)
            # Swap the default classification head for one sized to our classes.
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Load weights if path provided
        if model_path:
            print(f"Loading weights from {model_path}")
            # weights_only=True refuses arbitrary pickled objects; checkpoints may be
            # downloaded from ClearML / HuggingFace, so treat them as untrusted.
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
        return model

    def load_from_clearml(self, task_id=None, project_name=None, task_name=None,
                          model_name="CNN from Scratch"):
        """
        Load model from ClearML

        Looks up a task (by ID, or by project and task name), downloads the last
        output model registered on it, and builds it with ``_load_real_model``.
        On any failure the error is printed and the mock model is returned instead.

        Args:
            task_id: ClearML task ID (if known)
            project_name: ClearML project name
            task_name: ClearML task name
            model_name: Configuration name describing the stored weights'
                architecture (key of ``config.MODEL_CONFIGS``).

        Returns:
            Loaded model (or the mock fallback), on ``self.device`` and in eval mode.
        """
        # Only the package import is guarded here, so an ImportError raised later
        # (e.g. torchvision) is not misreported as a missing ClearML install.
        try:
            from clearml import Task, Model
        except ImportError:
            print("ClearML not installed. Install with: pip install clearml")
            return self._fallback_to_mock()

        try:
            if task_id:
                task = Task.get_task(task_id=task_id)
            elif project_name and task_name:
                # Get the latest task with this name
                task = Task.get_task(
                    project_name=project_name,
                    task_name=task_name
                )
            else:
                raise ValueError("Must provide either task_id or (project_name and task_name)")

            output_models = task.models.get('output')
            if not output_models:
                raise ValueError("No output model found in ClearML task")
            # Take the last registered output model (presumably the newest one).
            model_path = Model(output_models[-1].id).get_local_copy()

            model = self._prepare_model(self._load_real_model(model_name, model_path))
            print(f"Model loaded from ClearML task: {task_id or task_name}")
            return model
        except Exception as e:
            # Deliberate best-effort: the UI keeps working with the mock model.
            print(f"Error loading from ClearML: {e}")
            return self._fallback_to_mock()

    def load_from_huggingface(self, model_id, filename="model.pth",
                              model_name="CNN from Scratch"):
        """
        Load model from HuggingFace Hub

        Downloads ``filename`` from the repo and builds it with ``_load_real_model``.
        On any failure the error is printed and the mock model is returned instead.

        Args:
            model_id: HuggingFace model ID (e.g., "username/model-name")
            filename: Weights file inside the repo to download.
            model_name: Configuration name describing the stored weights'
                architecture (key of ``config.MODEL_CONFIGS``).

        Returns:
            Loaded model (or the mock fallback), on ``self.device`` and in eval mode.
        """
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            print("huggingface_hub not installed. Install with: pip install huggingface_hub")
            return self._fallback_to_mock()

        try:
            # Download model file; hf_hub_download returns a local file path.
            model_path = hf_hub_download(repo_id=model_id, filename=filename)
            model = self._prepare_model(self._load_real_model(model_name, model_path))
            print(f"Model loaded from HuggingFace: {model_id}")
            return model
        except Exception as e:
            # Deliberate best-effort: the UI keeps working with the mock model.
            print(f"Error loading from HuggingFace: {e}")
            return self._fallback_to_mock()
def get_model(use_mock=True, **kwargs):
    """
    Create a ``ModelLoader`` and load a model with it in a single call.

    Args:
        use_mock: Forwarded to ``ModelLoader``; True selects the mock model.
        **kwargs: Passed unchanged to ``ModelLoader.load_model``
            (for example ``model_name`` or ``model_path``).

    Returns:
        A ``(model, loader)`` tuple: the loaded model and the loader that built it.
    """
    model_loader = ModelLoader(use_mock=use_mock)
    return model_loader.load_model(**kwargs), model_loader
if __name__ == "__main__":
    # Manual smoke test: build the mock model and run one random batch through it.
    print("Testing model loading...")

    print("\n1. Loading mock model:")
    mock_model, mock_loader = get_model(use_mock=True)
    print(f"Model type: {type(mock_model).__name__}")
    print(f"Device: {mock_loader.device}")

    # One random input of shape (1, 3, 256, 256) on the loader's device; presumably
    # (batch, channels, height, width) -- TODO confirm against the model's expected input.
    fake_batch = torch.randn(1, 3, 256, 256)
    fake_batch = fake_batch.to(mock_loader.device)
    with torch.no_grad():
        logits = mock_model(fake_batch)
    print(f"Output shape: {logits.shape}")