Kunitomi's picture
Upload folder using huggingface_hub
196c526 verified
"""Core training functionality for bean detection models."""
# Standard library imports
import logging
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
# Third-party imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
# Local imports
from .timing import TimingProfiler, TimingStats
class BeanTrainer:
    """Trainer class for bean detection models.

    Wraps a torchvision-style detection model (one that returns a dict of
    losses when called with ``(images, targets)`` in train mode) and provides
    per-epoch training and validation loops with wall-clock timing.
    """

    def __init__(self, model: nn.Module, device: torch.device, config: Optional[Dict[str, Any]] = None):
        """Initialize trainer.

        Args:
            model: PyTorch model to train. Assumed to return a loss dict when
                called as ``model(images, targets)`` in train mode — TODO confirm
                against the concrete model class.
            device: Device to use for training.
            config: Optional configuration dictionary.
        """
        self.model = model
        self.device = device
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

    def train_one_epoch(
        self,
        train_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        profiler: Optional[TimingProfiler] = None,
        grad_clip_value: Optional[float] = None,
        log_freq: int = 10,
    ) -> Tuple[float, Dict[str, float], TimingStats]:
        """Train for one epoch with simple timing.

        Args:
            train_loader: Training data loader yielding ``(images, targets)``
                batches (lists of tensors / lists of target dicts).
            optimizer: Optimizer.
            epoch: Current epoch number (used for display/logging only).
            profiler: Optional timing profiler (unused, kept for compatibility).
            grad_clip_value: Optional gradient-norm clipping value. ``None`` or
                ``0`` disables clipping (0 previously disabled it via
                truthiness; that behavior is preserved explicitly).
            log_freq: Emit a logger message every ``log_freq`` batches.

        Returns:
            Tuple of (average loss, per-component average loss dict, timing
            stats with ``train_time`` set). Average loss is ``inf`` if every
            batch was skipped for NaN/Inf.
        """
        self.model.train()
        total_loss = 0.0
        loss_components_total: Dict[str, float] = {}
        num_batches = 0
        train_start = time.time()

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}', unit='batch')
        for batch_idx, (images, targets) in enumerate(pbar):
            images = [img.to(self.device) for img in images]
            targets = [{k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

            # Forward pass (gradients were cleared just before, so a skipped
            # batch leaves no stale gradients behind).
            optimizer.zero_grad()
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # Skip unstable batches rather than poisoning the weights.
            if torch.isnan(losses) or torch.isinf(losses):
                pbar.write(f"Warning: NaN or Inf loss detected in batch {batch_idx+1}, skipping...")
                continue

            loss_value = losses.item()

            # Backward pass
            losses.backward()

            # Gradient clipping for stability. Explicit check: None or 0
            # disables clipping (clipping to max_norm=0 would zero all grads).
            if grad_clip_value is not None and grad_clip_value > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip_value)

            optimizer.step()

            # Accumulate scalar metrics.
            total_loss += loss_value
            num_batches += 1
            for k, v in loss_dict.items():
                loss_components_total[k] = loss_components_total.get(k, 0.0) + v.item()

            # Progress bar always shows the latest loss; the logger records it
            # every `log_freq` batches (lazy %-args keep formatting off the
            # hot path when the level is disabled).
            pbar.set_postfix({'loss': loss_value})
            if log_freq > 0 and (batch_idx + 1) % log_freq == 0:
                self.logger.info(
                    "Epoch %d batch %d: loss=%.4f", epoch, batch_idx + 1, loss_value
                )

        train_time = time.time() - train_start

        # Guard against an epoch where every batch was skipped.
        avg_loss_components = {k: v / num_batches for k, v in loss_components_total.items()} if num_batches > 0 else {}
        avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')

        timing_stats = TimingStats()
        timing_stats.train_time = train_time
        return avg_loss, avg_loss_components, timing_stats

    def validate(
        self,
        val_loader: DataLoader,
        profiler: Optional[TimingProfiler] = None
    ) -> Tuple[float, Dict[str, float], TimingStats]:
        """Validate model and compute loss with timing.

        Args:
            val_loader: Validation data loader.
            profiler: Optional timing profiler (unused, kept for compatibility).

        Returns:
            Tuple of (average loss, per-component average loss dict, timing
            stats with ``val_time`` set). Average loss is ``0.0`` if no finite
            batch was seen.
        """
        # Detection models only return their loss dict in train mode, so we
        # temporarily switch to train mode; torch.no_grad() still prevents any
        # gradient bookkeeping. try/finally guarantees the model is restored
        # to eval mode even if the loader or forward pass raises.
        self.model.train()
        total_loss = 0.0
        loss_components_total: Dict[str, float] = {}
        num_batches = 0
        val_start = time.time()
        try:
            with torch.no_grad():
                for images, targets in val_loader:
                    images = [img.to(self.device) for img in images]
                    targets = [{k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

                    loss_dict = self.model(images, targets)
                    losses = sum(loss for loss in loss_dict.values())

                    # Only finite batches contribute to the averages.
                    if not torch.isnan(losses) and not torch.isinf(losses):
                        total_loss += losses.item()
                        num_batches += 1
                        for k, v in loss_dict.items():
                            loss_components_total[k] = loss_components_total.get(k, 0.0) + v.item()
        finally:
            self.model.eval()  # Always restore inference mode.

        val_time = time.time() - val_start
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        avg_loss_components = {k: v / num_batches for k, v in loss_components_total.items()} if num_batches > 0 else {}

        timing_stats = TimingStats()
        timing_stats.val_time = val_time
        return avg_loss, avg_loss_components, timing_stats