#!/usr/bin/env python
"""Model management module for Bean Vision project."""

from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn

from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ModelError, validate_device, safe_load_model_checkpoint


class BeanModel:
    """Bean detection model wrapper with utilities.

    Wraps a torchvision Mask R-CNN (ResNet50 FPN) configured for
    high-instance-count bean detection, and provides checkpoint I/O,
    optimizer/scheduler construction, and device management helpers.
    """

    def __init__(self, config: BeanVisionConfig):
        """Initialize the wrapper and eagerly build the model.

        Args:
            config: Project configuration; ``config.model.device`` and
                ``config.model.num_classes`` are read here and in
                ``create_model``.
        """
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.device = validate_device(config.model.device)
        self.model: Optional[nn.Module] = None

        # Initialize model
        self.model = self.create_model()

    def create_model(self) -> nn.Module:
        """Create a Mask R-CNN model with modified prediction heads.

        Builds a pretrained ``maskrcnn_resnet50_fpn`` with raised RPN/ROI
        detection limits (images may contain hundreds of beans), replaces
        the box and mask heads to match ``config.model.num_classes``, and
        moves the model to ``self.device``.

        Returns:
            The constructed model (also stored on ``self.model``).

        Raises:
            ModelError: If model construction fails for any reason.
        """
        try:
            # Load pre-trained model with increased detection limits for high bean counts
            model = maskrcnn_resnet50_fpn(
                weights="DEFAULT",
                rpn_pre_nms_top_n_train=6000,   # Increased from 2000
                rpn_pre_nms_top_n_test=3000,    # Increased from 1000
                rpn_post_nms_top_n_train=4000,  # Increased from 2000
                rpn_post_nms_top_n_test=2000,   # Increased from 1000
                box_detections_per_img=1000,    # Increased from 100
                box_score_thresh=0.05,          # Lower threshold for more detections
            )

            # Replace classifier head so class count matches the project config
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
                in_features, self.config.model.num_classes
            )

            # Replace mask predictor head
            in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
            hidden_layer = 256
            model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, hidden_layer, self.config.model.num_classes
            )

            # Move to device
            model.to(self.device)

            # Count parameters (fixed: the f-string previously contained a raw newline)
            num_params = sum(p.numel() for p in model.parameters())
            self.logger.info(
                f"Model created with {num_params:,} parameters on {self.device}"
            )

            self.model = model
            return model

        except Exception as e:
            # Chain the cause so the original traceback is preserved
            raise ModelError(f"Failed to create model: {e}") from e

    def load_checkpoint(self, checkpoint_path: Union[str, Path],
                        use_pretrained: bool = False) -> nn.Module:
        """Load model weights from a checkpoint file.

        Args:
            checkpoint_path: Path to a checkpoint saved by
                ``save_checkpoint`` (dict with ``model_state_dict``) or a
                bare state dict.
            use_pretrained: If True, skip loading and keep the pretrained
                base model as-is.

        Returns:
            The model with weights loaded.

        Raises:
            ModelError: If the checkpoint cannot be read or applied.
        """
        if self.model is None:
            self.model = self.create_model()

        if use_pretrained:
            self.logger.info("Using pretrained base model (no custom training)")
            return self.model

        try:
            checkpoint = safe_load_model_checkpoint(checkpoint_path, self.device)

            # Accept both full training checkpoints and bare state dicts
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                self.model.load_state_dict(checkpoint)
                self.logger.info(f"Loaded model state dict from {checkpoint_path}")

            return self.model

        except Exception as e:
            raise ModelError(f"Failed to load checkpoint: {e}") from e

    def save_checkpoint(self, filepath: Union[str, Path],
                        optimizer: Optional[torch.optim.Optimizer] = None,
                        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                        epoch: int = 0,
                        best_metric: float = 0.0,
                        train_loss: float = 0.0) -> None:
        """Save a training checkpoint with metadata.

        Args:
            filepath: Destination path for ``torch.save``.
            optimizer: If given, its state dict is included.
            scheduler: If given, its state dict is included.
            epoch: Current epoch number to record.
            best_metric: Best validation metric so far.
            train_loss: Most recent training loss.

        Raises:
            ModelError: If no model exists or saving fails.
        """
        if self.model is None:
            raise ModelError("No model to save")

        try:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'best_metric': best_metric,
                'train_loss': train_loss,
                'config': {
                    'num_classes': self.config.model.num_classes,
                    'device': str(self.device)
                }
            }

            # Explicit None checks (truthiness of these objects is not the intent)
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()

            torch.save(checkpoint, filepath)
            self.logger.info(f"Checkpoint saved: {filepath}")

        except Exception as e:
            raise ModelError(f"Failed to save checkpoint: {e}") from e

    def get_model_info(self) -> dict:
        """Return model architecture information and parameter statistics.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )

        return {
            'architecture': 'MaskR-CNN ResNet50 FPN',
            'num_classes': self.config.model.num_classes,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': total_params * 4 / (1024**2),  # Assuming float32
            'mode': 'training' if self.model.training else 'evaluation'
        }

    def set_training_mode(self, mode: bool = True) -> None:
        """Set model training mode (``train()`` if True, else ``eval()``).

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        if mode:
            self.model.train()
            self.logger.debug("Model set to training mode")
        else:
            self.model.eval()
            self.logger.debug("Model set to evaluation mode")

    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze (or unfreeze) all backbone parameters.

        Args:
            freeze: True to stop gradients through the backbone,
                False to re-enable them.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze

        status = "frozen" if freeze else "unfrozen"
        self.logger.info(f"Backbone {status}")

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Get an SGD optimizer configured from ``config.training``.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Get a StepLR learning-rate scheduler configured from ``config.training``."""
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.training.lr_scheduler_step,
            gamma=self.config.training.lr_scheduler_gamma
        )

    def count_parameters(self, trainable_only: bool = False) -> int:
        """Count model parameters.

        Args:
            trainable_only: If True, count only parameters with
                ``requires_grad`` set.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        if trainable_only:
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.model.parameters())

    def get_device(self) -> torch.device:
        """Get the device the model currently lives on."""
        return self.device

    def to_device(self, device: Union[str, torch.device]) -> None:
        """Move the model to the specified device.

        Args:
            device: Target device name or ``torch.device``.

        Raises:
            ModelError: If no model is loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")

        self.device = validate_device(device)
        self.model.to(self.device)
        self.logger.info(f"Model moved to {self.device}")


def create_model(config: BeanVisionConfig) -> BeanModel:
    """Factory function to create a BeanModel instance."""
    return BeanModel(config)