# Kunitomi's picture
# Upload folder using huggingface_hub
# 196c526 verified
#!/usr/bin/env python
"""Model management module for Bean Vision project."""
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ModelError, validate_device, safe_load_model_checkpoint
class BeanModel:
    """Bean detection model wrapper with utilities.

    Wraps a torchvision Mask R-CNN (ResNet50-FPN) whose box and mask heads
    are replaced to match ``config.model.num_classes``, and provides helpers
    for checkpointing, optimizer/scheduler construction, parameter counting,
    and device placement.
    """

    # Width of the hidden layer used when rebuilding the mask-predictor head.
    _MASK_HIDDEN_CHANNELS = 256

    def __init__(self, config: BeanVisionConfig):
        """Initialize the wrapper and build the model immediately.

        Args:
            config: Project configuration; ``config.model.device`` selects the
                target device and ``config.model.num_classes`` sizes the heads.

        Raises:
            ModelError: If model construction fails (propagated from
                :meth:`create_model`).
        """
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.device = validate_device(config.model.device)
        self.model: Optional[nn.Module] = None
        # create_model() also assigns self.model; the explicit assignment here
        # keeps the attribute's provenance obvious to readers.
        self.model = self.create_model()

    def create_model(self) -> nn.Module:
        """Create a Mask R-CNN model with heads sized for this config.

        Detection limits are raised well above the torchvision defaults
        because bean images can contain many hundreds of instances.

        Returns:
            The model, moved to ``self.device`` (also stored on ``self.model``).

        Raises:
            ModelError: Wrapping any underlying construction failure, with the
                original exception chained as the cause.
        """
        try:
            # Pre-trained COCO weights; RPN/box limits raised for dense scenes.
            model = maskrcnn_resnet50_fpn(
                weights="DEFAULT",
                rpn_pre_nms_top_n_train=6000,   # torchvision default: 2000
                rpn_pre_nms_top_n_test=3000,    # torchvision default: 1000
                rpn_post_nms_top_n_train=4000,  # torchvision default: 2000
                rpn_post_nms_top_n_test=2000,   # torchvision default: 1000
                box_detections_per_img=1000,    # torchvision default: 100
                box_score_thresh=0.05           # low threshold -> more detections
            )
            # Replace the box-classifier head so it predicts our class count.
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
                in_features, self.config.model.num_classes
            )
            # Replace the mask-predictor head likewise.
            in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
            model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, self._MASK_HIDDEN_CHANNELS, self.config.model.num_classes
            )
            model.to(self.device)
            num_params = sum(p.numel() for p in model.parameters())
            self.logger.info(f"Model created with {num_params:,} parameters on {self.device}")
            self.model = model
            return model
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ModelError(f"Failed to create model: {e}") from e

    def load_checkpoint(self, checkpoint_path: Union[str, Path], use_pretrained: bool = False) -> nn.Module:
        """Load model weights from a checkpoint file.

        Args:
            checkpoint_path: Path to a checkpoint produced by
                :meth:`save_checkpoint` or a raw ``state_dict`` file.
            use_pretrained: If True, skip loading entirely and return the
                pretrained base model as-is.

        Returns:
            The model with weights loaded (or the pretrained base model).

        Raises:
            ModelError: If reading or applying the checkpoint fails.
        """
        if self.model is None:
            self.model = self.create_model()
        if use_pretrained:
            self.logger.info("Using pretrained base model (no custom training)")
            return self.model
        try:
            checkpoint = safe_load_model_checkpoint(checkpoint_path, self.device)
            # Accept both our full-checkpoint dict and a bare state_dict.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                self.model.load_state_dict(checkpoint)
                self.logger.info(f"Loaded model state dict from {checkpoint_path}")
            return self.model
        except Exception as e:
            raise ModelError(f"Failed to load checkpoint: {e}") from e

    def save_checkpoint(self,
                        filepath: Union[str, Path],
                        optimizer: Optional[torch.optim.Optimizer] = None,
                        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                        epoch: int = 0,
                        best_metric: float = 0.0,
                        train_loss: float = 0.0) -> None:
        """Save model checkpoint with metadata.

        Args:
            filepath: Destination path for the checkpoint.
            optimizer: If given, its state dict is included for resuming.
            scheduler: If given, its state dict is included for resuming.
            epoch: Epoch number to record.
            best_metric: Best validation metric so far.
            train_loss: Most recent training loss.

        Raises:
            ModelError: If no model exists or serialization fails.
        """
        if self.model is None:
            raise ModelError("No model to save")
        try:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'best_metric': best_metric,
                'train_loss': train_loss,
                'config': {
                    'num_classes': self.config.model.num_classes,
                    'device': str(self.device)
                }
            }
            # Explicit None checks: truthiness of torch objects is not reliable.
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            torch.save(checkpoint, filepath)
            self.logger.info(f"Checkpoint saved: {filepath}")
        except Exception as e:
            raise ModelError(f"Failed to save checkpoint: {e}") from e

    def get_model_info(self) -> dict:
        """Return architecture and parameter statistics for the current model.

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return {
            'architecture': 'MaskR-CNN ResNet50 FPN',
            'num_classes': self.config.model.num_classes,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': total_params * 4 / (1024**2),  # assumes float32 weights
            'mode': 'training' if self.model.training else 'evaluation'
        }

    def set_training_mode(self, mode: bool = True) -> None:
        """Set model training mode (``train()`` if mode else ``eval()``).

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        if mode:
            self.model.train()
            self.logger.debug("Model set to training mode")
        else:
            self.model.eval()
            self.logger.debug("Model set to evaluation mode")

    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze (or unfreeze) all backbone parameters.

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze
        status = "frozen" if freeze else "unfrozen"
        self.logger.info(f"Backbone {status}")

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Build an SGD optimizer from ``config.training`` hyperparameters.

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Build a StepLR scheduler from ``config.training`` hyperparameters."""
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.training.lr_scheduler_step,
            gamma=self.config.training.lr_scheduler_gamma
        )

    def count_parameters(self, trainable_only: bool = False) -> int:
        """Count model parameters (optionally only those with grad enabled).

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        if trainable_only:
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.model.parameters())

    def get_device(self) -> torch.device:
        """Return the device the model lives on."""
        return self.device

    def to_device(self, device: Union[str, torch.device]) -> None:
        """Move the model to the specified device and record it.

        Raises:
            ModelError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ModelError("No model loaded")
        self.device = validate_device(device)
        self.model.to(self.device)
        self.logger.info(f"Model moved to {self.device}")
def create_model(config: BeanVisionConfig) -> BeanModel:
    """Factory function to create BeanModel instance."""
    wrapper = BeanModel(config)
    return wrapper