modelforge-backend/backend/services/training_controller.py
ModelForge CI
deploy: 2026-06-19 19:24 UTC
6761f70
Raw
History Blame Contribute Delete
4.64 kB
"""
Training Control Manager - Handles pause/resume and live parameter updates
"""
import threading
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class TrainingController:
    """
    Manages training state and live parameter updates for a single job.

    ``pause()`` / ``resume()`` toggle a ``threading.Event``; a training loop
    that calls ``wait_if_paused()`` blocks while that event is cleared.
    ``update_hyperparameters()`` edits the attached trainer's optimizer
    param groups in place.
    """

    # Maps each hot-swappable update key to the optimizer param-group key it
    # writes. Other params (dropout, batch_size, ...) require a trainer
    # restart and are reported back as unsupported.
    _HOT_SWAPPABLE = {'learning_rate': 'lr', 'weight_decay': 'weight_decay'}

    def __init__(self, job_id: str):
        """
        Args:
            job_id: Identifier used as a prefix in every log message.
        """
        self.job_id = job_id

        # Threading control: event set == running, event cleared == paused
        self.pause_event = threading.Event()
        self.pause_event.set()  # Start in running state

        # State
        self.is_paused = False
        self.trainer = None  # Will be set by trainer

        logger.info(f"[{job_id}] TrainingController initialized")

    def pause(self) -> Dict[str, str]:
        """
        Pause training.

        Returns:
            ``{"status": "paused"}`` when the state changed,
            ``{"status": "already_paused"}`` otherwise.
        """
        if not self.is_paused:
            self.pause_event.clear()
            self.is_paused = True
            logger.info(f"[{self.job_id}] Training PAUSED")
            return {"status": "paused"}
        return {"status": "already_paused"}

    def resume(self) -> Dict[str, str]:
        """
        Resume training.

        Returns:
            ``{"status": "resumed"}`` when the state changed,
            ``{"status": "already_running"}`` otherwise.
        """
        if self.is_paused:
            self.pause_event.set()
            self.is_paused = False
            logger.info(f"[{self.job_id}] Training RESUMED")
            return {"status": "resumed"}
        return {"status": "already_running"}

    def wait_if_paused(self) -> None:
        """
        Block execution if paused. Call this in training loop.

        Returns immediately while the controller is in the running state.
        """
        self.pause_event.wait()

    def _parse_updates(self, updates: Dict[str, Any]) -> Dict[str, float]:
        """
        Convert and range-check the hot-swappable values found in ``updates``.

        Raises:
            TypeError, ValueError: a value is not convertible to float, or is
                NaN, infinite or negative.
        """
        parsed = {}
        for name in self._HOT_SWAPPABLE:
            if name not in updates:
                continue
            value = float(updates[name])
            # The chained comparison is false for NaN, +/-inf and negatives,
            # so one test rejects every value an optimizer cannot use.
            if not 0.0 <= value < float("inf"):
                raise ValueError(
                    f"{name} must be a finite, non-negative number, got {value!r}"
                )
            parsed[name] = value
        return parsed

    def update_hyperparameters(self, updates: Dict[str, Any]) -> Dict[str, Any]:
        """
        Hot-swap hyperparameters in the running trainer

        All values are parsed and validated before any param group is
        modified, so an invalid value never leaves the optimizer partially
        updated.

        Args:
            updates: Dictionary of parameter updates
                e.g., {"learning_rate": 5e-5, "weight_decay": 0.02}

        Returns:
            Status dictionary with applied changes:
            ``{"status": "success" | "no_changes", "applied": {...},
            "errors": [...] | None}``. Each ``applied`` entry holds the new
            value and the old value read from the last param group.
            ``{"error": message}`` is returned when no trainer is attached,
            a value is invalid, or the optimizer cannot be accessed.
        """
        if not self.trainer:
            return {"error": "Trainer not initialized"}

        applied = {}
        errors = []

        try:
            # Everything that can fail on bad input runs before the first
            # write to the optimizer.
            new_values = self._parse_updates(updates)
            unsupported = sorted(set(updates.keys()) - set(self._HOT_SWAPPABLE), key=str)

            # NOTE(review): if the trainer also drives an LR scheduler, the
            # scheduler may overwrite 'lr' on its next step - confirm against
            # the trainer in use.
            for param_group in self.trainer.optimizer.param_groups:
                # Update learning rate
                if 'learning_rate' in new_values:
                    new_lr = new_values['learning_rate']
                    old_lr = param_group['lr']
                    param_group['lr'] = new_lr
                    applied['learning_rate'] = {'old': old_lr, 'new': new_lr}

                # Update weight decay
                if 'weight_decay' in new_values:
                    new_wd = new_values['weight_decay']
                    old_wd = param_group.get('weight_decay', 0.0)
                    param_group['weight_decay'] = new_wd
                    applied['weight_decay'] = {'old': old_wd, 'new': new_wd}

            # Log once per update rather than once per param group
            if 'learning_rate' in applied:
                change = applied['learning_rate']
                logger.info(
                    f"[{self.job_id}] LR updated: {change['old']} -> {change['new']}"
                )
            if 'weight_decay' in applied:
                change = applied['weight_decay']
                logger.info(
                    f"[{self.job_id}] Weight decay updated: {change['old']} -> {change['new']}"
                )

            # Note: Other params like dropout, batch_size require trainer restart
            if unsupported:
                errors.append(f"Unsupported hot-swap params: {unsupported}")
                logger.warning(f"[{self.job_id}] Cannot hot-swap: {unsupported}")

            return {
                "status": "success" if applied else "no_changes",
                "applied": applied,
                "errors": errors if errors else None
            }

        except Exception as e:
            # Callers receive an error dictionary rather than an exception,
            # so this stays a broad catch at the API boundary.
            logger.error(f"[{self.job_id}] Failed to update hyperparameters: {e}")
            return {"error": str(e)}

    def get_status(self) -> Dict[str, Any]:
        """Get current training control status"""
        return {
            "is_paused": self.is_paused,
            "can_update_params": self.trainer is not None
        }
# Module-wide lookup table: job_id -> the TrainingController serving that job
_active_controllers: Dict[str, TrainingController] = dict()
def get_controller(job_id: str) -> Optional[TrainingController]:
    """Look up the controller registered for ``job_id``; ``None`` if there is none."""
    try:
        return _active_controllers[job_id]
    except KeyError:
        return None
def create_controller(job_id: str) -> TrainingController:
    """
    Build a TrainingController for ``job_id`` and store it in the registry.

    Any controller already stored under the same ``job_id`` is replaced.
    """
    new_controller = TrainingController(job_id)
    _active_controllers.update({job_id: new_controller})
    return new_controller
def remove_controller(job_id: str) -> None:
    """
    Remove controller when training completes

    Does nothing when no controller is registered for ``job_id``.
    """
    # Delete in a single step instead of check-then-delete, so a concurrent
    # removal of the same job cannot raise KeyError between the two lookups.
    try:
        del _active_controllers[job_id]
    except KeyError:
        return
    logger.info(f"[{job_id}] Controller removed")