| | |
| | """Model management module for Bean Vision project.""" |
| |
|
| | from pathlib import Path |
| | from typing import Optional, Union |
| | import torch |
| | import torch.nn as nn |
| | import torchvision |
| | from torchvision.models.detection import maskrcnn_resnet50_fpn |
| |
|
| | from bean_vision.config import BeanVisionConfig |
| | from bean_vision.utils.logging import get_logger |
| | from bean_vision.utils.misc import ModelError, validate_device, safe_load_model_checkpoint |
| |
|
| |
|
class BeanModel:
    """Wrapper around a Mask R-CNN detection model with lifecycle utilities.

    Handles model construction, checkpoint save/load, device placement,
    optimizer/scheduler creation, train/eval mode switching, and parameter
    bookkeeping, all driven by a :class:`BeanVisionConfig` instance.
    """

    def __init__(self, config: BeanVisionConfig):
        """Initialize the wrapper and eagerly build the network.

        Args:
            config: Project configuration; ``config.model.device`` and
                ``config.model.num_classes`` are read during construction.
        """
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.device = validate_device(config.model.device)
        self.model: Optional[nn.Module] = None
        # Build immediately so the instance is usable right after construction.
        self.model = self.create_model()

    def create_model(self) -> nn.Module:
        """Create a Mask R-CNN ResNet50-FPN model with modified heads.

        The box and mask predictor heads are replaced so the network predicts
        ``config.model.num_classes`` classes. RPN proposal counts and the
        per-image detection cap are raised above the torchvision defaults,
        presumably to cope with densely packed bean instances.

        Returns:
            The constructed model, already moved to ``self.device`` and also
            stored on ``self.model``.

        Raises:
            ModelError: If model construction fails for any reason.
        """
        try:
            model = maskrcnn_resnet50_fpn(
                weights="DEFAULT",
                rpn_pre_nms_top_n_train=6000,
                rpn_pre_nms_top_n_test=3000,
                rpn_post_nms_top_n_train=4000,
                rpn_post_nms_top_n_test=2000,
                box_detections_per_img=1000,
                box_score_thresh=0.05
            )

            # Replace the box classification head to match our class count.
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
                in_features, self.config.model.num_classes
            )

            # Replace the mask head likewise; 256 hidden channels matches the
            # torchvision reference implementation.
            in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
            hidden_layer = 256
            model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, hidden_layer, self.config.model.num_classes
            )

            model.to(self.device)

            num_params = sum(p.numel() for p in model.parameters())
            self.logger.info(f"Model created with {num_params:,} parameters on {self.device}")

            self.model = model
            return model

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ModelError(f"Failed to create model: {e}") from e

    def load_checkpoint(self, checkpoint_path: Union[str, Path], use_pretrained: bool = False) -> nn.Module:
        """Load model weights from a checkpoint file.

        Args:
            checkpoint_path: Path to the checkpoint file.
            use_pretrained: If True, skip loading and return the base model
                with its default (pretrained) weights.

        Returns:
            The model with weights loaded (or the pretrained base model).

        Raises:
            ModelError: If the checkpoint cannot be loaded.
        """
        if self.model is None:
            self.model = self.create_model()

        if use_pretrained:
            self.logger.info("Using pretrained base model (no custom training)")
            return self.model

        try:
            checkpoint = safe_load_model_checkpoint(checkpoint_path, self.device)

            # Accept both full training checkpoints and bare state dicts.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                self.model.load_state_dict(checkpoint)
                self.logger.info(f"Loaded model state dict from {checkpoint_path}")

            return self.model

        except Exception as e:
            raise ModelError(f"Failed to load checkpoint: {e}") from e

    def save_checkpoint(self,
                        filepath: Union[str, Path],
                        optimizer: Optional[torch.optim.Optimizer] = None,
                        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                        epoch: int = 0,
                        best_metric: float = 0.0,
                        train_loss: float = 0.0) -> None:
        """Save a model checkpoint with training metadata.

        Args:
            filepath: Destination path for the checkpoint file.
            optimizer: Optional optimizer whose state is included.
            scheduler: Optional LR scheduler whose state is included.
            epoch: Epoch number to record.
            best_metric: Best validation metric so far.
            train_loss: Most recent training loss.

        Raises:
            ModelError: If no model exists or saving fails.
        """
        if self.model is None:
            raise ModelError("No model to save")

        try:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'best_metric': best_metric,
                'train_loss': train_loss,
                # Minimal config snapshot so the checkpoint is self-describing.
                'config': {
                    'num_classes': self.config.model.num_classes,
                    'device': str(self.device)
                }
            }

            # Explicit None checks: truthiness is not a reliable presence test
            # for optimizer/scheduler objects.
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()

            torch.save(checkpoint, filepath)
            self.logger.info(f"Checkpoint saved: {filepath}")

        except Exception as e:
            raise ModelError(f"Failed to save checkpoint: {e}") from e

    def get_model_info(self) -> dict:
        """Return model information and parameter statistics.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        return {
            'architecture': 'MaskR-CNN ResNet50 FPN',
            'num_classes': self.config.model.num_classes,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            # Assumes 4 bytes per parameter (float32).
            'parameter_size_mb': total_params * 4 / (1024**2),
            'mode': 'training' if self.model.training else 'evaluation'
        }

    def set_training_mode(self, mode: bool = True) -> None:
        """Set model training (True) or evaluation (False) mode.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        if mode:
            self.model.train()
            self.logger.debug("Model set to training mode")
        else:
            self.model.eval()
            self.logger.debug("Model set to evaluation mode")

    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze (or unfreeze) the backbone's parameters.

        Args:
            freeze: If True, backbone gradients are disabled; if False,
                they are re-enabled.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze

        status = "frozen" if freeze else "unfrozen"
        self.logger.info(f"Backbone {status}")

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Return an SGD optimizer configured from ``config.training``.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Return a StepLR scheduler configured from ``config.training``."""
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.training.lr_scheduler_step,
            gamma=self.config.training.lr_scheduler_gamma
        )

    def count_parameters(self, trainable_only: bool = False) -> int:
        """Count model parameters.

        Args:
            trainable_only: If True, count only parameters with
                ``requires_grad`` set.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        if trainable_only:
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.model.parameters())

    def get_device(self) -> torch.device:
        """Return the device the model is placed on."""
        return self.device

    def to_device(self, device: Union[str, torch.device]) -> None:
        """Move the model to the specified device and record it.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        self.device = validate_device(device)
        self.model.to(self.device)
        self.logger.info(f"Model moved to {self.device}")
| |
|
| |
|
def create_model(config: BeanVisionConfig) -> BeanModel:
    """Factory function to create BeanModel instance."""
    # Construction is delegated entirely to the BeanModel initializer.
    instance = BeanModel(config)
    return instance