"""Training execution helper (Quick Start).""" from pathlib import Path from typing import Optional import torch import torch.nn as nn from torch.utils.data import DataLoader from llm_lab.config import TrainConfig from .trainer import Trainer from .checkpoint import CheckpointManager from llm_lab.utils import auto_configure, set_seed def _setup_and_train( model: nn.Module, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], config: TrainConfig, seq_len: int, auto_config: bool, fallback_checkpoint_dir: str, ) -> Trainer: """Shared setup logic: auto-config, Drive check, seed, train.""" if auto_config: config = auto_configure(config) # Check Google Drive mount (Colab) if "/content/drive" in config.checkpoint_dir: drive_path = Path("/content/drive/MyDrive") if not drive_path.exists(): print("\n⚠️ Google Drive is not mounted!") print(" Run in Colab: from google.colab import drive; drive.mount('/content/drive')") print(" Switching to local path.") config.checkpoint_dir = fallback_checkpoint_dir set_seed(config.seed) trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len) trainer.train() return trainer def start_training( model: nn.Module, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None, config: Optional[TrainConfig] = None, seq_len: int = 2048, auto_config: bool = True, ) -> Trainer: """Starts training (one-line execution). Usage (Colab): ```python from model import LLMModel, ModelConfig from data_pipeline import setup_data_pipeline, DataConfig from trainer import start_training, TrainConfig # 1. Create model model_config = ModelConfig.base_1b() model = LLMModel(model_config) # 2. Data pipeline tok, train_dl, val_dl = setup_data_pipeline("pretrained") # 3. Start training (automatic checkpoint restoration) trainer = start_training(model, train_dl, val_dl) ``` """ return _setup_and_train( model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, config=config or TrainConfig(), seq_len=seq_len, auto_config=auto_config, fallback_checkpoint_dir="./checkpoints", ) def start_cpt( model: nn.Module, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader] = None, config: Optional[TrainConfig] = None, pretrained_checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints", seq_len: int = 2048, auto_config: bool = True, ) -> Trainer: """Starts Continual Pre-Training from a pretrained checkpoint. Loads pretrained model weights, creates a fresh optimizer, and trains with the new config (typically lower LR + new data). Args: model: The LLM model (same architecture as the pretrained checkpoint). train_dataloader: Mixed data dataloader (code + general). val_dataloader: Validation dataloader. config: CPT training config. Defaults to TrainConfig.code_cpt_1b(). pretrained_checkpoint_dir: Path to the base pretrained checkpoint directory. seq_len: Sequence length (should match model config). auto_config: Whether to auto-detect GPU and adjust settings. Returns: Trainer instance after training completes. 
def start_cpt(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    config: Optional[TrainConfig] = None,
    pretrained_checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints",
    seq_len: int = 2048,
    auto_config: bool = True,
) -> Trainer:
    """Starts Continual Pre-Training (CPT) from a pretrained checkpoint.

    Loads the pretrained model weights, creates a fresh optimizer, and
    trains with the new config (typically a lower LR and new data).

    Args:
        model: The LLM model (same architecture as the pretrained checkpoint).
        train_dataloader: Mixed-data dataloader (code + general).
        val_dataloader: Validation dataloader.
        config: CPT training config. Defaults to TrainConfig.code_cpt_1b().
        pretrained_checkpoint_dir: Path to the base pretrained checkpoint directory.
        seq_len: Sequence length (should match the model config).
        auto_config: Whether to auto-detect the GPU and adjust settings.

    Returns:
        Trainer instance after training completes.

    Usage (Colab):
        ```python
        from llm_lab.config import ModelConfig, DataConfig, TrainConfig
        from llm_lab.model import LLMModel
        from llm_lab.data import setup_cpt_data_pipeline
        from llm_lab.training import start_cpt

        model = LLMModel(ModelConfig.base_1b())
        tok, train_dl, val_dl = setup_cpt_data_pipeline()
        trainer = start_cpt(model, train_dl, val_dl)
        ```
    """
    # Load pretrained weights (model only, no optimizer state)
    CheckpointManager.load_model_only(
        model=model,
        checkpoint_dir=pretrained_checkpoint_dir,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    # Note: Trainer._try_resume() looks in the CPT checkpoint dir.
    # If no CPT checkpoint exists yet, training starts from step 0 with
    # the pretrained weights loaded above.
    return _setup_and_train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        config=config or TrainConfig.code_cpt_1b(),
        seq_len=seq_len,
        auto_config=auto_config,
        fallback_checkpoint_dir="./checkpoints-cpt-code",
    )
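
# A minimal sketch of running CPT with a custom config instead of the
# TrainConfig.code_cpt_1b() default. The `learning_rate` field name is an
# assumption about TrainConfig (not confirmed in this module); the docstring
# above only says CPT typically uses a lower LR plus new data:
#
#     from llm_lab.config import TrainConfig
#
#     cpt_config = TrainConfig.code_cpt_1b()
#     cpt_config.learning_rate = 3e-5  # assumed field; lower LR for CPT
#     trainer = start_cpt(model, train_dl, val_dl, config=cpt_config)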