"""Training execution helper (Quick Start)."""
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from llm_lab.config import TrainConfig
from .trainer import Trainer
from .checkpoint import CheckpointManager
from llm_lab.utils import auto_configure, set_seed
def _setup_and_train(
model: nn.Module,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader],
config: TrainConfig,
seq_len: int,
auto_config: bool,
fallback_checkpoint_dir: str,
) -> Trainer:
"""Shared setup logic: auto-config, Drive check, seed, train."""
if auto_config:
config = auto_configure(config)
# Check Google Drive mount (Colab)
if "/content/drive" in config.checkpoint_dir:
drive_path = Path("/content/drive/MyDrive")
if not drive_path.exists():
print("\n⚠️ Google Drive is not mounted!")
print(" Run in Colab: from google.colab import drive; drive.mount('/content/drive')")
print(" Switching to local path.")
config.checkpoint_dir = fallback_checkpoint_dir
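
    # Seed all RNGs so runs are reproducible.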
set_seed(config.seed)
trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
trainer.train()
return trainer


def start_training(
model: nn.Module,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader] = None,
config: Optional[TrainConfig] = None,
seq_len: int = 2048,
auto_config: bool = True,
) -> Trainer:
"""Starts training (one-line execution).
Usage (Colab):
```python
from model import LLMModel, ModelConfig
from data_pipeline import setup_data_pipeline, DataConfig
from trainer import start_training, TrainConfig
# 1. Create model
model_config = ModelConfig.base_1b()
model = LLMModel(model_config)
# 2. Data pipeline
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
# 3. Start training (automatic checkpoint restoration)
trainer = start_training(model, train_dl, val_dl)
```
"""
return _setup_and_train(
model=model,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
config=config or TrainConfig(),
seq_len=seq_len,
auto_config=auto_config,
fallback_checkpoint_dir="./checkpoints",
)


def start_cpt(
model: nn.Module,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader] = None,
config: Optional[TrainConfig] = None,
pretrained_checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints",
seq_len: int = 2048,
auto_config: bool = True,
) -> Trainer:
"""Starts Continual Pre-Training from a pretrained checkpoint.
Loads pretrained model weights, creates a fresh optimizer,
and trains with the new config (typically lower LR + new data).
Args:
model: The LLM model (same architecture as the pretrained checkpoint).
train_dataloader: Mixed data dataloader (code + general).
val_dataloader: Validation dataloader.
config: CPT training config. Defaults to TrainConfig.code_cpt_1b().
pretrained_checkpoint_dir: Path to the base pretrained checkpoint directory.
seq_len: Sequence length (should match model config).
auto_config: Whether to auto-detect GPU and adjust settings.
Returns:
Trainer instance after training completes.
Usage (Colab):
```python
from llm_lab.config import ModelConfig, DataConfig, TrainConfig
from llm_lab.model import LLMModel
from llm_lab.data import setup_cpt_data_pipeline
from llm_lab.training import start_cpt
model = LLMModel(ModelConfig.base_1b())
tok, train_dl, val_dl = setup_cpt_data_pipeline()
trainer = start_cpt(model, train_dl, val_dl)
```
"""
# Load pretrained weights (model only, no optimizer)
CheckpointManager.load_model_only(
model=model,
checkpoint_dir=pretrained_checkpoint_dir,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
# Note: Trainer._try_resume() will look in the CPT checkpoint dir.
# If no CPT checkpoint exists yet, it starts from step 0 with
# the pretrained weights we just loaded.
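    # ``pretrained_checkpoint_dir`` is read once, above; ongoing CPT
    # checkpoints are written to (and resumed from) ``config.checkpoint_dir``.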
return _setup_and_train(
model=model,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
config=config or TrainConfig.code_cpt_1b(),
seq_len=seq_len,
auto_config=auto_config,
fallback_checkpoint_dir="./checkpoints-cpt-code",
)
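

# ---------------------------------------------------------------------------
# Usage sketch: keeping CPT checkpoints on Google Drive so a Colab restart
# can resume. This mirrors the docstring examples above; the Drive path is
# an illustrative assumption, not a required layout.
#
#   from llm_lab.config import TrainConfig
#
#   cpt_config = TrainConfig.code_cpt_1b()
#   cpt_config.checkpoint_dir = "/content/drive/MyDrive/llm-1b-lab/checkpoints-cpt-code"
#   trainer = start_cpt(model, train_dl, val_dl, config=cpt_config)
# ---------------------------------------------------------------------------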