Training Guide
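This guide covers the three-stage training process in CLaRa: what each stage trains, its key parameters, the command that launches it, and the data format it expects.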
Overview
CLaRa uses a three-stage training approach:
- Stage 1: Compression Pretraining
- Stage 2: Compression Instruction Tuning
- Stage 3: End-to-End Fine-tuning (CLaRa)
Stage 1: Compression Pretraining
Train the compressor to learn effective document compression.
Key Parameters
- --stage stage1: Training stage identifier
- --compress_rate: Compression rate (default: 32)
- --doc_max_length: Maximum document length (default: 256)
- --mse_loss: Use MSE loss for compression alignment
- --qa_loss: Use QA loss for semantic preservation
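As a rough illustration of how --compress_rate and --doc_max_length interact, the sketch below assumes the compression rate is the ratio of document tokens to compressed tokens. That interpretation and the helper name `compressed_token_count` are assumptions made for this example, not taken from the CLaRa code.

```python
import math


def compressed_token_count(doc_max_length: int = 256, compress_rate: int = 32) -> int:
    """Estimate compressed tokens per document.

    Assumption: compressed length = document length / compression rate,
    rounded up. The defaults mirror the parameter defaults listed above.
    """
    if doc_max_length <= 0 or compress_rate <= 0:
        raise ValueError("doc_max_length and compress_rate must be positive")
    return math.ceil(doc_max_length / compress_rate)


# With the documented defaults: 256 / 32 = 8 compressed tokens per document.
print(compressed_token_count())
```

Under this assumption, raising --compress_rate shrinks the compressed representation, while raising --doc_max_length grows it.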
Example Command
bash scripts/train_pretraining.sh
Data Format
Stage 1 Pretraining Data:
{
"data_type": "qa",
"question": ["Question 1", "Question 2", ...],
"answers": ["Answer 1", "Answer 2", ...],
"docs": ["Document 1", "Document 2", ...]
}
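A record in this layout can be sanity-checked before training. The following is a minimal sketch, assuming that questions and answers are paired by index and that each record is serialized as one JSON object; the function name `validate_stage1_record` is hypothetical and not part of CLaRa.

```python
import json


def validate_stage1_record(record: dict) -> None:
    """Raise ValueError if the record does not match the Stage 1 layout."""
    if not isinstance(record.get("data_type"), str):
        raise ValueError("'data_type' must be a string such as 'qa'")
    for key in ("question", "answers", "docs"):
        value = record.get(key)
        if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
            raise ValueError(f"'{key}' must be a list of strings")
    # Assumption: question[i] is answered by answers[i].
    if len(record["question"]) != len(record["answers"]):
        raise ValueError("'question' and 'answers' must have the same length")


example = {
    "data_type": "qa",
    "question": ["Question 1", "Question 2"],
    "answers": ["Answer 1", "Answer 2"],
    "docs": ["Document 1", "Document 2"],
}
validate_stage1_record(example)
serialized = json.dumps(example)  # one JSON object per record
```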
Stage 2: Compression Instruction Tuning
Fine-tune the compressor on instruction-following tasks.
Key Parameters
- --stage stage1_2: Training stage identifier
- --pretrain_checkpoint: Path to Stage 1 checkpoint
- --generation_top_k: Top-k sampling (default: 5)
- --mse_loss: Continue using MSE loss
- --do_eval_gen: Enable generation evaluation
Example Command
bash scripts/train_instruction_tuning.sh
Data Format
Stage 2 Instruction Tuning Data:
{
"question": "Single question text",
"docs": ["Document 1", "Document 2", ...],
"gold_answer": "Reference answer",
"answer": "Generated answer"
}
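Unlike Stage 1, each Stage 2 record holds a single question string rather than a list. The sketch below describes that shape as a typed dictionary and checks a record against it; `Stage2Record` and `is_stage2_record` are illustrative names, not identifiers from the CLaRa code base.

```python
from typing import List, TypedDict


class Stage2Record(TypedDict):
    question: str
    docs: List[str]
    gold_answer: str
    answer: str


def is_stage2_record(record: dict) -> bool:
    """Return True if the record has the Stage 2 fields with the expected types."""
    string_fields = ("question", "gold_answer", "answer")
    # Extra fields are tolerated; only the documented ones are required.
    if not {*string_fields, "docs"} <= set(record):
        return False
    if not all(isinstance(record[f], str) for f in string_fields):
        return False
    docs = record["docs"]
    return isinstance(docs, list) and all(isinstance(d, str) for d in docs)


example: Stage2Record = {
    "question": "Single question text",
    "docs": ["Document 1", "Document 2"],
    "gold_answer": "Reference answer",
    "answer": "Generated answer",
}
print(is_stage2_record(example))
```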
Stage 3: End-to-End Training
Jointly train the reranker and the generator with retrieval.
Key Parameters
- --stage stage2: Training stage identifier
- --pretrain_checkpoint: Path to Stage 2 checkpoint
- --generation_top_k: Top-k sampling for generation
- --do_eval_gen: Enable generation evaluation
Example Command
bash scripts/train_stage_end_to_end.sh
Data Format
Stage 3 End-to-End Data:
{
"question": "Single question text",
"docs": ["Document 1", "Document 2", ...],
"gold_answer": "Reference answer"
}
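The Stage 3 layout lists the same fields as Stage 2 minus "answer". If existing Stage 2 records were to be reused, they could be projected onto the Stage 3 fields as sketched below. Whether CLaRa intends the same data to be shared across stages is not stated here, and `to_stage3_record` is a hypothetical helper.

```python
def to_stage3_record(stage2_record: dict) -> dict:
    """Keep only the fields listed in the Stage 3 format."""
    stage3_keys = ("question", "docs", "gold_answer")
    missing = [k for k in stage3_keys if k not in stage2_record]
    if missing:
        raise KeyError(f"missing fields: {missing}")
    return {k: stage2_record[k] for k in stage3_keys}


stage2_example = {
    "question": "Single question text",
    "docs": ["Document 1", "Document 2"],
    "gold_answer": "Reference answer",
    "answer": "Generated answer",
}
stage3_example = to_stage3_record(stage2_example)
print(sorted(stage3_example))
```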
Distributed Training
All training stages support distributed training across multiple nodes and GPUs.
Key Parameters
- --max_len: Maximum sequence length (2048 for stage1/stage2, 1024 for stage3)
- --train_batch_size: Training batch size
- --micro_train_batch_size: Micro batch size for gradient accumulation
- --learning_rate: Learning rate (1e-4 for stage1/stage2, 5e-6 for stage3)
- --max_epochs: Maximum training epochs
- --zero_stage: ZeRO optimization stage (default: 2)
- --bf16: Use bfloat16 precision
- --flash_attn: Use Flash Attention 2
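The two batch-size parameters together determine how many micro-batches are accumulated before each optimizer step. The sketch below assumes the usual DeepSpeed relationship, where the global batch size equals the per-GPU micro batch size times the accumulation steps times the number of GPUs, and assumes that --train_batch_size and --micro_train_batch_size map onto those quantities. The function name and the example values are hypothetical.

```python
def gradient_accumulation_steps(
    train_batch_size: int, micro_train_batch_size: int, world_size: int
) -> int:
    """Accumulation steps implied by a global batch size.

    Assumption: train_batch_size is global, micro_train_batch_size is per GPU,
    and world_size is the total number of GPUs across all nodes.
    """
    samples_per_pass = micro_train_batch_size * world_size
    if samples_per_pass <= 0:
        raise ValueError("micro_train_batch_size and world_size must be positive")
    if train_batch_size % samples_per_pass != 0:
        raise ValueError(
            "train_batch_size must be divisible by micro_train_batch_size * world_size"
        )
    return train_batch_size // samples_per_pass


# Hypothetical setup: global batch 128, micro batch 2, 8 GPUs -> 128 / (2 * 8) = 8.
print(gradient_accumulation_steps(128, 2, 8))
```

Under these assumptions, adding GPUs while keeping --train_batch_size fixed lowers the number of accumulation steps rather than changing the effective batch size.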
Monitoring Training
Training progress is logged via:
- Console output
- Wandb (if configured)
- Checkpoint files
Checkpoints are saved at the path specified by --save_path.