| """Training execution helper (Quick Start).""" |
|
|
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
| from llm_lab.config import TrainConfig |
| from .trainer import Trainer |
| from .checkpoint import CheckpointManager |
| from llm_lab.utils import auto_configure, set_seed |
|
|
|
|
def _setup_and_train(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    config: TrainConfig,
    seq_len: int,
    auto_config: bool,
    fallback_checkpoint_dir: str,
) -> Trainer:
    """Common launch path: optional auto-config, Drive-mount fallback, seeding, then train.

    Args:
        model: Model to train.
        train_dataloader: Training batches.
        val_dataloader: Optional validation batches.
        config: Training configuration (may be replaced by ``auto_configure``).
        seq_len: Sequence length forwarded to the Trainer.
        auto_config: When True, let ``auto_configure`` adapt ``config`` to the GPU.
        fallback_checkpoint_dir: Local checkpoint path used when Google Drive
            is referenced but not mounted.

    Returns:
        The Trainer instance after ``train()`` has completed.
    """
    if auto_config:
        config = auto_configure(config)

    # Colab convenience: if checkpoints target Google Drive but the Drive
    # isn't mounted, warn and redirect to a local directory instead of
    # failing mid-run on the first checkpoint write.
    wants_drive = "/content/drive" in config.checkpoint_dir
    drive_mounted = Path("/content/drive/MyDrive").exists()
    if wants_drive and not drive_mounted:
        print("\n⚠️ Google Drive is not mounted!")
        print(" Run in Colab: from google.colab import drive; drive.mount('/content/drive')")
        print(" Switching to local path.")
        config.checkpoint_dir = fallback_checkpoint_dir

    set_seed(config.seed)

    trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
    trainer.train()
    return trainer
|
|
|
|
def start_training(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    config: Optional[TrainConfig] = None,
    seq_len: int = 2048,
    auto_config: bool = True,
) -> Trainer:
    """One-line entry point for pre-training.

    Falls back to a default ``TrainConfig`` when none is supplied and
    delegates all setup (auto-config, Drive check, seeding) to the shared
    launch helper.

    Usage (Colab):
        ```python
        from model import LLMModel, ModelConfig
        from data_pipeline import setup_data_pipeline, DataConfig
        from trainer import start_training, TrainConfig

        # 1. Create model
        model_config = ModelConfig.base_1b()
        model = LLMModel(model_config)

        # 2. Data pipeline
        tok, train_dl, val_dl = setup_data_pipeline("pretrained")

        # 3. Start training (automatic checkpoint restoration)
        trainer = start_training(model, train_dl, val_dl)
        ```
    """
    # Default config when the caller passes nothing.
    effective_config = config or TrainConfig()
    return _setup_and_train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        config=effective_config,
        seq_len=seq_len,
        auto_config=auto_config,
        fallback_checkpoint_dir="./checkpoints",
    )
|
|
|
|
def start_cpt(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    config: Optional[TrainConfig] = None,
    pretrained_checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints",
    seq_len: int = 2048,
    auto_config: bool = True,
) -> Trainer:
    """Starts Continual Pre-Training from a pretrained checkpoint.

    Loads pretrained model weights, creates a fresh optimizer,
    and trains with the new config (typically lower LR + new data).

    Args:
        model: The LLM model (same architecture as the pretrained checkpoint).
        train_dataloader: Mixed data dataloader (code + general).
        val_dataloader: Validation dataloader.
        config: CPT training config. Defaults to TrainConfig.code_cpt_1b().
        pretrained_checkpoint_dir: Path to the base pretrained checkpoint directory.
        seq_len: Sequence length (should match model config).
        auto_config: Whether to auto-detect GPU and adjust settings.

    Returns:
        Trainer instance after training completes.

    Usage (Colab):
        ```python
        from llm_lab.config import ModelConfig, DataConfig, TrainConfig
        from llm_lab.model import LLMModel
        from llm_lab.data import setup_cpt_data_pipeline
        from llm_lab.training import start_cpt

        model = LLMModel(ModelConfig.base_1b())
        tok, train_dl, val_dl = setup_cpt_data_pipeline()
        trainer = start_cpt(model, train_dl, val_dl)
        ```
    """
    # Pick the device the weights should be restored onto.
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Weights only — the optimizer is rebuilt from scratch by the Trainer,
    # which is what CPT wants (new LR schedule over new data).
    CheckpointManager.load_model_only(
        model=model,
        checkpoint_dir=pretrained_checkpoint_dir,
        device=target_device,
    )

    # Default to the code-CPT recipe when no config is given.
    cpt_config = config or TrainConfig.code_cpt_1b()
    return _setup_and_train(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        config=cpt_config,
        seq_len=seq_len,
        auto_config=auto_config,
        fallback_checkpoint_dir="./checkpoints-cpt-code",
    )
|
|