File size: 3,024 Bytes
30c14cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
---
layout: default
title: Training Guide
permalink: /training/
---

# Training Guide

This guide covers the three-stage training process in CLaRa.

## Overview

CLaRa uses a three-stage training approach:

1. **Stage 1**: Compression Pretraining
2. **Stage 2**: Compression Instruction Tuning
3. **Stage 3**: End-to-End Fine-tuning (CLaRa)

## Stage 1: Compression Pretraining

Train the compressor to learn effective document compression.

### Key Parameters

- `--stage stage1`: Training stage identifier
- `--compress_rate`: Compression rate (default: 32)
- `--doc_max_length`: Maximum document length (default: 256)
- `--mse_loss`: Use MSE loss for compression alignment
- `--qa_loss`: Use QA loss for semantic preservation

### Example Command

```bash
bash scripts/train_pretraining.sh
```

### Data Format

**Stage 1 Pretraining Data:**
```json
{
    "data_type": "qa",
    "question": ["Question 1", "Question 2", ...],
    "answers": ["Answer 1", "Answer 2", ...],
    "docs": ["Document 1", "Document 2", ...]
}
```

## Stage 2: Compression Instruction Tuning

Fine-tune the compressor on instruction-following tasks.

### Key Parameters

- `--stage stage1_2`: Training stage identifier
- `--pretrain_checkpoint`: Path to Stage 1 checkpoint
- `--generation_top_k`: Top-k sampling (default: 5)
- `--mse_loss`: Continue using MSE loss
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_instruction_tuning.sh
```

### Data Format

**Stage 2 Instruction Tuning Data:**
```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer",
    "answer": "Generated answer"
}
```

## Stage 3: End-to-End Training

Jointly train reranker and generator with retrieval.

### Key Parameters

- `--stage stage2`: Training stage identifier
- `--pretrain_checkpoint`: Path to Stage 2 checkpoint
- `--generation_top_k`: Top-k sampling for generation
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_stage_end_to_end.sh
```

### Data Format

**Stage 3 End-to-End Data:**
```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer"
}
```

## Distributed Training

All training stages support distributed training across multiple nodes and GPUs.

### Key Parameters

- `--max_len`: Maximum sequence length (2048 for stage1/stage2, 1024 for stage3)
- `--train_batch_size`: Training batch size
- `--micro_train_batch_size`: Micro batch size for gradient accumulation
- `--learning_rate`: Learning rate (1e-4 for stage1/stage2, 5e-6 for stage3)
- `--max_epochs`: Maximum training epochs
- `--zero_stage`: ZeRO optimization stage (default: 2)
- `--bf16`: Use bfloat16 precision
- `--flash_attn`: Use Flash Attention 2

## Monitoring Training

Training progress is logged via:
- Console output
- Wandb (if configured)
- Checkpoint files

Checkpoints are saved at the path specified by `--save_path`.