---
layout: default
title: Training Guide
permalink: /training/
---

# Training Guide

This guide covers the three-stage training process in CLaRa.

## Overview

CLaRa uses a three-stage training approach:

1. **Stage 1**: Compression Pretraining
2. **Stage 2**: Compression Instruction Tuning
3. **Stage 3**: End-to-End Fine-tuning (CLaRa)

## Stage 1: Compression Pretraining

Train the compressor to learn effective document compression.

### Key Parameters

- `--stage stage1`: Training stage identifier
- `--compress_rate`: Compression rate (default: 32)
- `--doc_max_length`: Maximum document length (default: 256)
- `--mse_loss`: Use MSE loss for compression alignment
- `--qa_loss`: Use QA loss for semantic preservation

### Example Command

```bash
bash scripts/train_pretraining.sh
```
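For reference, here is a flag-level sketch of a Stage 1 run. The `train.py` entry point is a placeholder for whatever module the script actually wraps; the flags and defaults are the ones listed above.

```bash
# Illustrative sketch only: "train.py" stands in for the repo's real entry
# point, which scripts/train_pretraining.sh wraps. Flags are documented above.
python train.py \
    --stage stage1 \
    --compress_rate 32 \
    --doc_max_length 256 \
    --mse_loss \
    --qa_loss
```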

### Data Format

Stage 1 Pretraining Data:

```json
{
    "data_type": "qa",
    "question": ["Question 1", "Question 2", ...],
    "answers": ["Answer 1", "Answer 2", ...],
    "docs": ["Document 1", "Document 2", ...]
}
```
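A minimal sketch of materializing one such record on disk, e.g. to smoke-test the data pipeline; the file path is hypothetical:

```bash
# Hypothetical path: each record bundles parallel question/answers/docs lists.
mkdir -p data
cat > data/stage1_sample.json <<'EOF'
{
    "data_type": "qa",
    "question": ["Who wrote Hamlet?", "What is the capital of France?"],
    "answers": ["William Shakespeare", "Paris"],
    "docs": ["Hamlet is a tragedy by William Shakespeare.", "Paris is the capital of France."]
}
EOF
```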

## Stage 2: Compression Instruction Tuning

Fine-tune the compressor on instruction-following tasks.

### Key Parameters

- `--stage stage1_2`: Training stage identifier
- `--pretrain_checkpoint`: Path to Stage 1 checkpoint
- `--generation_top_k`: Top-k sampling (default: 5)
- `--mse_loss`: Continue using MSE loss
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_instruction_tuning.sh
```
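As with Stage 1, a flag-level sketch with a placeholder `train.py` entry point; the checkpoint path is yours to fill in:

```bash
# Illustrative sketch only; flags are the Stage 2 parameters listed above.
python train.py \
    --stage stage1_2 \
    --pretrain_checkpoint /path/to/stage1_checkpoint \
    --generation_top_k 5 \
    --mse_loss \
    --do_eval_gen
```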

### Data Format

Stage 2 Instruction Tuning Data:

```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer",
    "answer": "Generated answer"
}
```
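A quick sanity check that a record carries all four Stage 2 fields, using `jq` (the file name is hypothetical):

```bash
# Exits non-zero if any required Stage 2 key is missing (requires jq).
jq -e 'has("question") and has("docs") and has("gold_answer") and has("answer")' \
    data/stage2_sample.json
```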

## Stage 3: End-to-End Training

Jointly train the reranker and generator with retrieval.

### Key Parameters

- `--stage stage2`: Training stage identifier
- `--pretrain_checkpoint`: Path to Stage 2 checkpoint
- `--generation_top_k`: Top-k sampling for generation
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_stage_end_to_end.sh
```
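And the corresponding flag-level sketch (again with a placeholder entry point; the top-k value of 5 is carried over from Stage 2 for illustration):

```bash
# Illustrative sketch only; flags are the Stage 3 parameters listed above.
python train.py \
    --stage stage2 \
    --pretrain_checkpoint /path/to/stage2_checkpoint \
    --generation_top_k 5 \
    --do_eval_gen
```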

### Data Format

Stage 3 End-to-End Data:

```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer"
}
```
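Note that a Stage 3 record is a Stage 2 record minus the generated `answer` field, so a Stage 2 file can be converted directly (file names are hypothetical):

```bash
# Drop the "answer" field to turn a Stage 2 record into a Stage 3 record.
jq 'del(.answer)' data/stage2_sample.json > data/stage3_sample.json
```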

## Distributed Training

All training stages support distributed training across multiple nodes and GPUs.

### Key Parameters

- `--max_len`: Maximum sequence length (2048 for stage1/stage2, 1024 for stage3)
- `--train_batch_size`: Training batch size
- `--micro_train_batch_size`: Micro batch size for gradient accumulation (see the sketch after this list)
- `--learning_rate`: Learning rate (1e-4 for stage1/stage2, 5e-6 for stage3)
- `--max_epochs`: Maximum training epochs
- `--zero_stage`: ZeRO optimization stage (default: 2)
- `--bf16`: Use bfloat16 precision
- `--flash_attn`: Use Flash Attention 2
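The micro batch size controls gradient accumulation. A common convention, assumed here rather than taken from the CLaRa code, is that the number of accumulation steps is the global batch size divided by the per-step batch across all GPUs:

```bash
# Assumed relationship (not verified against the CLaRa trainer):
# accumulation steps = train_batch_size / (micro_train_batch_size * num_gpus).
TRAIN_BATCH_SIZE=128
MICRO_TRAIN_BATCH_SIZE=4
NUM_GPUS=8
echo $(( TRAIN_BATCH_SIZE / (MICRO_TRAIN_BATCH_SIZE * NUM_GPUS) ))  # -> 4
```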

## Monitoring Training

Training progress is logged via:

- Console output
- Wandb (if configured)
- Checkpoint files

Checkpoints are saved at the path specified by `--save_path`.