Spaces:
Sleeping
Sleeping
"""
Training Control Manager - handles pause/resume and live hyperparameter updates.
"""
| import threading | |
| import logging | |
| from typing import Dict, Any, Optional | |
logger = logging.getLogger(__name__)


class TrainingController:
    """
    Manages pause/resume state and live hyperparameter updates for one job.

    The training loop calls ``wait_if_paused()``; other code calls ``pause()``,
    ``resume()`` and ``update_hyperparameters()`` to steer the running job.

    NOTE: ``pause()`` and ``resume()`` update ``pause_event`` and ``is_paused``
    as two separate steps without a lock; callers that may issue them
    concurrently should serialize those calls themselves.
    """

    def __init__(self, job_id: str):
        """
        Create a controller in the running (not paused) state.

        Args:
            job_id: Identifier of the training job; used as a log prefix.
        """
        self.job_id = job_id

        # Threading control: the event is *set* while training may proceed and
        # *cleared* while it is paused, so ``wait()`` blocks only when paused.
        self.pause_event = threading.Event()
        self.pause_event.set()  # Start in running state

        # State
        self.is_paused = False
        self.trainer: Optional[Any] = None  # Will be set by trainer

        logger.info(f"[{job_id}] TrainingController initialized")

    def pause(self) -> Dict[str, str]:
        """
        Pause training.

        Returns:
            ``{"status": "paused"}`` if the state changed, otherwise
            ``{"status": "already_paused"}``.
        """
        if self.is_paused:
            return {"status": "already_paused"}

        self.pause_event.clear()
        self.is_paused = True
        logger.info(f"[{self.job_id}] Training PAUSED")
        return {"status": "paused"}

    def resume(self) -> Dict[str, str]:
        """
        Resume training.

        Returns:
            ``{"status": "resumed"}`` if the state changed, otherwise
            ``{"status": "already_running"}``.
        """
        if not self.is_paused:
            return {"status": "already_running"}

        self.pause_event.set()
        self.is_paused = False
        logger.info(f"[{self.job_id}] Training RESUMED")
        return {"status": "resumed"}

    def wait_if_paused(self) -> None:
        """
        Block execution if paused. Call this in training loop.

        Returns immediately while the controller is in the running state.
        """
        self.pause_event.wait()

    def update_hyperparameters(self, updates: Dict[str, Any]) -> Dict[str, Any]:
        """
        Hot-swap hyperparameters in the running trainer.

        Only ``learning_rate`` and ``weight_decay`` are supported. The new
        value is written to every param group of ``self.trainer.optimizer``.
        All requested values are validated before any param group is touched,
        so a bad value never leaves the optimizer partially updated.

        Args:
            updates: Dictionary of parameter updates
                e.g., {"learning_rate": 5e-5, "weight_decay": 0.02}

        Returns:
            On success, a dictionary with keys ``status`` (``"success"`` or
            ``"no_changes"``), ``applied`` (``{name: {"old": ..., "new": ...}}``;
            ``old`` is the value read from the last param group) and ``errors``
            (list of messages, or ``None``). On failure, ``{"error": message}``.
        """
        if self.trainer is None:
            return {"error": "Trainer not initialized"}

        applied: Dict[str, Dict[str, Any]] = {}
        errors = []

        try:
            # Coerce everything up front: a conversion failure must abort the
            # request before the optimizer has been modified at all.
            new_lr = float(updates['learning_rate']) if 'learning_rate' in updates else None
            new_wd = float(updates['weight_decay']) if 'weight_decay' in updates else None

            optimizer = getattr(self.trainer, 'optimizer', None)
            if optimizer is None:
                return {"error": "Optimizer not initialized"}

            # Access optimizer param groups
            for param_group in optimizer.param_groups:
                # Update learning rate
                if new_lr is not None:
                    old_lr = param_group['lr']
                    param_group['lr'] = new_lr
                    applied['learning_rate'] = {'old': old_lr, 'new': new_lr}
                    logger.info(f"[{self.job_id}] LR updated: {old_lr} → {new_lr}")

                # Update weight decay
                if new_wd is not None:
                    old_wd = param_group.get('weight_decay', 0.0)
                    param_group['weight_decay'] = new_wd
                    applied['weight_decay'] = {'old': old_wd, 'new': new_wd}
                    logger.info(f"[{self.job_id}] Weight decay updated: {old_wd} → {new_wd}")

            # Note: Other params like dropout, batch_size require trainer restart
            unsupported = set(updates.keys()) - {'learning_rate', 'weight_decay'}
            if unsupported:
                errors.append(f"Unsupported hot-swap params: {list(unsupported)}")
                logger.warning(f"[{self.job_id}] Cannot hot-swap: {unsupported}")

            return {
                "status": "success" if applied else "no_changes",
                "applied": applied,
                "errors": errors or None,
            }
        except Exception as e:
            # Boundary handler: callers receive an error dict instead of an
            # exception; the traceback is kept in the log for diagnosis.
            logger.exception(f"[{self.job_id}] Failed to update hyperparameters: {e}")
            return {"error": str(e)}

    def get_status(self) -> Dict[str, Any]:
        """
        Get current training control status.

        Returns:
            ``{"is_paused": bool, "can_update_params": bool}``, where
            ``can_update_params`` is true once a trainer has been attached.
        """
        return {
            "is_paused": self.is_paused,
            "can_update_params": self.trainer is not None,
        }
# Module-level registry: maps each job_id to the controller created for it.
_active_controllers: Dict[str, TrainingController] = {}
def get_controller(job_id: str) -> Optional[TrainingController]:
    """Return the controller registered for ``job_id``, or ``None`` if there is none."""
    try:
        return _active_controllers[job_id]
    except KeyError:
        return None
def create_controller(job_id: str) -> TrainingController:
    """
    Build a new TrainingController and register it under ``job_id``.

    Any controller already registered under the same ``job_id`` is replaced.
    """
    _active_controllers[job_id] = new_controller = TrainingController(job_id)
    return new_controller
def remove_controller(job_id: str) -> None:
    """
    Remove the controller registered for ``job_id`` when training completes.

    Does nothing (and logs nothing) if no controller is registered under
    ``job_id``.
    """
    # Delete in a single dict operation instead of "check, then delete": the
    # two-step form can raise KeyError if the entry disappears in between.
    try:
        del _active_controllers[job_id]
    except KeyError:
        return
    logger.info(f"[{job_id}] Controller removed")