| """Core training functionality for bean detection models.""" |
|
|
| |
| import logging |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Tuple |
|
|
| |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| |
| from .timing import TimingProfiler, TimingStats |
|
|
|
|
class BeanTrainer:
    """Trainer class for bean detection models.

    Wraps a detection-style model (one that, in train mode, returns a dict of
    named losses when called as ``model(images, targets)``) and provides
    per-epoch training and validation loops with simple wall-clock timing.
    """

    def __init__(self, model: nn.Module, device: torch.device, config: Optional[Dict[str, Any]] = None):
        """Initialize trainer.

        Args:
            model: PyTorch model to train.
            device: Device to use for training.
            config: Optional configuration dictionary (stored as-is; an empty
                dict is used when ``None`` is given).
        """
        self.model = model
        self.device = device
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

    def _move_batch_to_device(self, images, targets):
        """Move one detection batch to ``self.device``.

        Args:
            images: Iterable of image tensors.
            targets: Iterable of per-image target dicts; tensor values are
                moved to the device, non-tensor values (ids, metadata) are
                passed through unchanged.

        Returns:
            Tuple of (list of device images, list of device target dicts).
        """
        images = [img.to(self.device) for img in images]
        targets = [
            {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]
        return images, targets

    def train_one_epoch(
        self,
        train_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        profiler: Optional[TimingProfiler] = None,
        grad_clip_value: Optional[float] = None,
        log_freq: int = 10
    ) -> Tuple[float, Dict[str, float], TimingStats]:
        """Train for one epoch with simple timing.

        Batches whose summed loss is NaN or Inf are skipped entirely (no
        backward pass, no optimizer step) so one divergent batch cannot
        poison the weights.

        Args:
            train_loader: Training data loader yielding (images, targets) batches.
            optimizer: Optimizer.
            epoch: Current epoch number (used only in the progress-bar label).
            profiler: Optional timing profiler (unused, kept for API compatibility).
            grad_clip_value: Optional max gradient norm; clipping is applied
                only when this is a positive number.
            log_freq: Frequency of logging (unused by this loop, kept for API
                compatibility).

        Returns:
            Tuple of (average loss, per-component average loss dict, timing stats).
            If every batch was skipped, the average loss is ``inf`` and the
            components dict is empty.
        """
        self.model.train()
        total_loss = 0.0
        loss_components_total: Dict[str, float] = {}
        num_batches = 0

        train_start = time.time()

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}', unit='batch')

        for batch_idx, (images, targets) in enumerate(pbar):
            images, targets = self._move_batch_to_device(images, targets)

            optimizer.zero_grad()
            # In train mode the model returns a dict of named losses.
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # Skip divergent batches. Gradients were already zeroed above,
            # so nothing from this batch leaks into the next iteration.
            if torch.isnan(losses) or torch.isinf(losses):
                pbar.write(f"Warning: NaN or Inf loss detected in batch {batch_idx+1}, skipping...")
                continue

            loss_value = losses.item()

            losses.backward()

            # Explicit positive-value check instead of bare truthiness: a
            # negative max-norm would be meaningless to clip_grad_norm_.
            if grad_clip_value is not None and grad_clip_value > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip_value)

            optimizer.step()

            total_loss += loss_value
            num_batches += 1

            # Accumulate each named loss component for per-epoch averaging.
            for k, v in loss_dict.items():
                loss_components_total[k] = loss_components_total.get(k, 0.0) + v.item()

            pbar.set_postfix({'loss': loss_value})

        train_time = time.time() - train_start

        avg_loss_components = {k: v / num_batches for k, v in loss_components_total.items()} if num_batches > 0 else {}
        avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')

        timing_stats = TimingStats()
        timing_stats.train_time = train_time

        return avg_loss, avg_loss_components, timing_stats

    def validate(
        self,
        val_loader: DataLoader,
        profiler: Optional[TimingProfiler] = None
    ) -> Tuple[float, Dict[str, float], TimingStats]:
        """Validate model and compute loss with timing.

        The model is deliberately put in *train* mode for the forward pass,
        because detection models of this style only return their loss dict in
        train mode; gradients stay disabled via ``torch.no_grad()``.
        NOTE(review): running forward in train mode still updates any
        BatchNorm running statistics with validation data — confirm this is
        acceptable for the models used here.

        Batches with a NaN/Inf summed loss are excluded from the averages.
        The model is left in eval mode on return.

        Args:
            val_loader: Validation data loader yielding (images, targets) batches.
            profiler: Optional timing profiler (unused, kept for API compatibility).

        Returns:
            Tuple of (average loss, per-component average loss dict, timing stats).
            If no batch produced a finite loss, the average loss is ``0.0`` and
            the components dict is empty.
        """
        self.model.train()
        total_loss = 0.0
        loss_components_total: Dict[str, float] = {}
        num_batches = 0

        val_start = time.time()

        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = self._move_batch_to_device(images, targets)

                loss_dict = self.model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

                # Only finite losses contribute to the validation averages.
                if not torch.isnan(losses) and not torch.isinf(losses):
                    total_loss += losses.item()
                    num_batches += 1

                    for k, v in loss_dict.items():
                        loss_components_total[k] = loss_components_total.get(k, 0.0) + v.item()

        self.model.eval()

        val_time = time.time() - val_start
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        avg_loss_components = {k: v / num_batches for k, v in loss_components_total.items()} if num_batches > 0 else {}

        timing_stats = TimingStats()
        timing_stats.val_time = val_time

        return avg_loss, avg_loss_components, timing_stats