File size: 8,522 Bytes
196c526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#!/usr/bin/env python
"""Model management module for Bean Vision project."""

from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn

from bean_vision.config import BeanVisionConfig
from bean_vision.utils.logging import get_logger
from bean_vision.utils.misc import ModelError, validate_device, safe_load_model_checkpoint


class BeanModel:
    """Bean detection model wrapper with utilities."""
    
    def __init__(self, config: BeanVisionConfig):
        self.config = config
        self.logger = get_logger(self.__class__.__name__)
        self.device = validate_device(config.model.device)
        self.model: Optional[nn.Module] = None
        # Initialize model
        self.model = self.create_model()
    
    def create_model(self) -> nn.Module:
        """Create MaskR-CNN model with modified heads."""
        try:
            # Load pre-trained model with increased detection limits for high bean counts
            model = maskrcnn_resnet50_fpn(
                weights="DEFAULT",
                rpn_pre_nms_top_n_train=6000,    # Increased from 2000
                rpn_pre_nms_top_n_test=3000,     # Increased from 1000
                rpn_post_nms_top_n_train=4000,   # Increased from 2000
                rpn_post_nms_top_n_test=2000,    # Increased from 1000
                box_detections_per_img=1000,     # Increased from 100
                box_score_thresh=0.05            # Lower threshold for more detections
            )
            
            # Replace classifier head
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
                in_features, self.config.model.num_classes
            )
            
            # Replace mask predictor head
            in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
            hidden_layer = 256
            model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
                in_features_mask, hidden_layer, self.config.model.num_classes
            )
            
            # Move to device
            model.to(self.device)
            
            # Count parameters
            num_params = sum(p.numel() for p in model.parameters())
            self.logger.info(f"Model created with {num_params:,} parameters on {self.device}")
            
            self.model = model
            return model
        
        except Exception as e:
            raise ModelError(f"Failed to create model: {e}")
    
    def load_checkpoint(self, checkpoint_path: Union[str, Path], use_pretrained: bool = False) -> nn.Module:
        """Load model from checkpoint."""
        if self.model is None:
            self.model = self.create_model()
        
        if use_pretrained:
            self.logger.info("Using pretrained base model (no custom training)")
            return self.model
        
        try:
            checkpoint = safe_load_model_checkpoint(checkpoint_path, self.device)
            
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                self.model.load_state_dict(checkpoint)
                self.logger.info(f"Loaded model state dict from {checkpoint_path}")
            
            return self.model
        
        except Exception as e:
            raise ModelError(f"Failed to load checkpoint: {e}")
    
    def save_checkpoint(self, 
                       filepath: Union[str, Path],
                       optimizer: Optional[torch.optim.Optimizer] = None,
                       scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                       epoch: int = 0,
                       best_metric: float = 0.0,
                       train_loss: float = 0.0) -> None:
        """Save model checkpoint with metadata."""
        if self.model is None:
            raise ModelError("No model to save")
        
        try:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'best_metric': best_metric,
                'train_loss': train_loss,
                'config': {
                    'num_classes': self.config.model.num_classes,
                    'device': str(self.device)
                }
            }
            
            if optimizer:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            if scheduler:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            
            torch.save(checkpoint, filepath)
            self.logger.info(f"Checkpoint saved: {filepath}")
        
        except Exception as e:
            raise ModelError(f"Failed to save checkpoint: {e}")
    
    def get_model_info(self) -> dict:
        """Get model information and statistics."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        
        return {
            'architecture': 'MaskR-CNN ResNet50 FPN',
            'num_classes': self.config.model.num_classes,
            'device': str(self.device),
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'parameter_size_mb': total_params * 4 / (1024**2),  # Assuming float32
            'mode': 'training' if self.model.training else 'evaluation'
        }
    
    def set_training_mode(self, mode: bool = True) -> None:
        """Set model training mode."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        if mode:
            self.model.train()
            self.logger.debug("Model set to training mode")
        else:
            self.model.eval()
            self.logger.debug("Model set to evaluation mode")
    
    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze/unfreeze model backbone."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        for param in self.model.backbone.parameters():
            param.requires_grad = not freeze
        
        status = "frozen" if freeze else "unfrozen"
        self.logger.info(f"Backbone {status}")
    
    def get_optimizer(self) -> torch.optim.Optimizer:
        """Get SGD optimizer configured from config."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay
        )
    
    def get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Get learning rate scheduler configured from config."""
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.training.lr_scheduler_step,
            gamma=self.config.training.lr_scheduler_gamma
        )
    
    def count_parameters(self, trainable_only: bool = False) -> int:
        """Count model parameters."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        if trainable_only:
            return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.model.parameters())
    
    def get_device(self) -> torch.device:
        """Get model device."""
        return self.device
    
    def to_device(self, device: Union[str, torch.device]) -> None:
        """Move model to specified device."""
        if self.model is None:
            raise ModelError("No model loaded")
        
        self.device = validate_device(device)
        self.model.to(self.device)
        self.logger.info(f"Model moved to {self.device}")


def create_model(config: BeanVisionConfig) -> BeanModel:
    """Build and return a :class:`BeanModel` for the given configuration."""
    wrapper = BeanModel(config)
    return wrapper