#!/usr/bin/env python
"""Model management module for Bean Vision project."""
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ModelError, validate_device, safe_load_model_checkpoint
class BeanModel:
    """Bean detection model wrapper with utilities.

    Wraps a torchvision Mask R-CNN (ResNet50 FPN backbone) whose RPN/ROI
    detection limits are raised to cope with images containing very large
    numbers of beans, and adds checkpoint I/O, optimizer/scheduler
    construction, and device-management helpers.
    """

    def __init__(self, config: BeanVisionConfig):
        """Initialize the wrapper and eagerly build the network.

        Args:
            config: Project configuration; ``config.model.device`` selects
                the compute device and ``config.model.num_classes`` sizes
                the prediction heads.
        """
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.device = validate_device(config.model.device)
        self.model: Optional[nn.Module] = None
        # Eagerly build the network so the instance is immediately usable.
        self.model = self.create_model()

    def _require_model(self, message: str = "No model loaded") -> nn.Module:
        """Return the wrapped model, raising ModelError(message) if absent."""
        if self.model is None:
            raise ModelError(message)
        return self.model

    def create_model(self) -> nn.Module:
        """Create MaskR-CNN model with modified heads.

        Returns:
            The freshly created model, already moved to ``self.device`` and
            also stored on ``self.model``.

        Raises:
            ModelError: If model construction fails; the original exception
                is chained as the cause.
        """
        try:
            # Load pre-trained model with increased detection limits for high bean counts
            model = maskrcnn_resnet50_fpn(
                weights="DEFAULT",
                rpn_pre_nms_top_n_train=6000,   # Increased from 2000
                rpn_pre_nms_top_n_test=3000,    # Increased from 1000
                rpn_post_nms_top_n_train=4000,  # Increased from 2000
                rpn_post_nms_top_n_test=2000,   # Increased from 1000
                box_detections_per_img=1000,    # Increased from 100
                box_score_thresh=0.05           # Lower threshold for more detections
            )
            # Replace classifier head so it matches the configured class count.
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
                in_features, self.config.model.num_classes
            )
            # Replace mask predictor head likewise.
            in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
            hidden_layer = 256
            model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, hidden_layer, self.config.model.num_classes
            )
            # Move to device
            model.to(self.device)
            # Count parameters for the log line
            num_params = sum(p.numel() for p in model.parameters())
            self.logger.info(f"Model created with {num_params:,} parameters on {self.device}")
            self.model = model
            return model
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ModelError(f"Failed to create model: {e}") from e

    def load_checkpoint(self, checkpoint_path: Union[str, Path], use_pretrained: bool = False) -> nn.Module:
        """Load model from checkpoint.

        Args:
            checkpoint_path: Path to a saved checkpoint; may be a full
                checkpoint dict (with a ``'model_state_dict'`` key) or a
                bare state dict.
            use_pretrained: If True, skip loading entirely and return the
                pretrained base model unchanged.

        Returns:
            The model with weights applied.

        Raises:
            ModelError: If reading or applying the checkpoint fails; the
                original exception is chained as the cause.
        """
        if self.model is None:
            self.model = self.create_model()
        if use_pretrained:
            self.logger.info("Using pretrained base model (no custom training)")
            return self.model
        try:
            checkpoint = safe_load_model_checkpoint(checkpoint_path, self.device)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                # Anything else is assumed to be a bare state dict.
                self.model.load_state_dict(checkpoint)
                self.logger.info(f"Loaded model state dict from {checkpoint_path}")
            return self.model
        except Exception as e:
            raise ModelError(f"Failed to load checkpoint: {e}") from e

    def save_checkpoint(self,
                        filepath: Union[str, Path],
                        optimizer: Optional[torch.optim.Optimizer] = None,
                        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                        epoch: int = 0,
                        best_metric: float = 0.0,
                        train_loss: float = 0.0) -> None:
        """Save model checkpoint with metadata.

        Args:
            filepath: Destination path for ``torch.save``.
            optimizer: If given, its state dict is stored too.
            scheduler: If given, its state dict is stored too.
            epoch: Epoch number recorded in the checkpoint.
            best_metric: Best validation metric recorded in the checkpoint.
            train_loss: Training loss recorded in the checkpoint.

        Raises:
            ModelError: If no model exists or saving fails; in the latter
                case the original exception is chained as the cause.
        """
        model = self._require_model("No model to save")
        try:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_metric': best_metric,
                'train_loss': train_loss,
                'config': {
                    'num_classes': self.config.model.num_classes,
                    'device': str(self.device)
                }
            }
            # Explicit None checks: presence, not truthiness, is what matters.
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            torch.save(checkpoint, filepath)
            self.logger.info(f"Checkpoint saved: {filepath}")
        except Exception as e:
            raise ModelError(f"Failed to save checkpoint: {e}") from e

    def get_model_info(self) -> dict:
        """Get model information and statistics.

        Returns:
            Dict with architecture name, class count, device, parameter
            counts, approximate parameter size in MB, and current mode.

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return {
            'architecture': 'MaskR-CNN ResNet50 FPN',
            'num_classes': self.config.model.num_classes,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': total_params * 4 / (1024**2),  # Assuming float32
            'mode': 'training' if model.training else 'evaluation'
        }

    def set_training_mode(self, mode: bool = True) -> None:
        """Set model training mode: train() when True, eval() when False.

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        if mode:
            model.train()
            self.logger.debug("Model set to training mode")
        else:
            model.eval()
            self.logger.debug("Model set to evaluation mode")

    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze/unfreeze model backbone by toggling requires_grad.

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        for param in model.backbone.parameters():
            param.requires_grad = not freeze
        status = "frozen" if freeze else "unfrozen"
        self.logger.info(f"Backbone {status}")

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Get SGD optimizer configured from config (lr, momentum, weight decay).

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        return torch.optim.SGD(
            model.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Get StepLR learning rate scheduler configured from config."""
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.training.lr_scheduler_step,
            gamma=self.config.training.lr_scheduler_gamma
        )

    def count_parameters(self, trainable_only: bool = False) -> int:
        """Count model parameters, optionally restricted to trainable ones.

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        if trainable_only:
            return sum(p.numel() for p in model.parameters() if p.requires_grad)
        return sum(p.numel() for p in model.parameters())

    def get_device(self) -> torch.device:
        """Get model device."""
        return self.device

    def to_device(self, device: Union[str, torch.device]) -> None:
        """Move model to specified device and remember the new device.

        Raises:
            ModelError: If no model is loaded.
        """
        model = self._require_model()
        self.device = validate_device(device)
        model.to(self.device)
        self.logger.info(f"Model moved to {self.device}")
def create_model(config: BeanVisionConfig) -> BeanModel:
    """Construct and return a new :class:`BeanModel` for *config*."""
    wrapper = BeanModel(config)
    return wrapper