ishwarraja commited on
Commit
63a0be4
·
verified ·
1 Parent(s): 4e8c3fc

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +30 -0
  2. train.py +255 -0
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # # 1) Create project
3
+ # mkdir phi2-grpo-qlora && cd phi2-grpo-qlora
4
+
5
+ # # 2) Create & activate a virtual environment (choose one)
6
+ # python -m venv .venv && source .venv/bin/activate
7
+ # # or: conda create -n phi2-grpo python=3.10 -y && conda activate phi2-grpo
8
+
9
+ # # 3) Install core deps (pin to mature versions that work well together)
10
+ # pip install -U "transformers>=4.40" accelerate datasets peft bitsandbytes trl gradio
11
+ # # If your GPU supports bfloat16 well, also:
12
+ # pip install torch --index-url https://download.pytorch.org/whl/cu121
13
+
14
+ # # 4) Optional (to log in to HF Hub for pushing adapters later)
15
+ # pip install -U huggingface_hub
16
+
17
+
18
+
19
+ # Core
20
+ transformers>=4.40
21
+ accelerate
22
+ datasets
23
+ peft
24
+ trl
25
+ bitsandbytes
26
+ gradio
27
+
28
+ # If logs show CUDA wheel mismatch, uncomment and adjust per Spaces GPU doc:
29
+ # --extra-index-url https://download.pytorch.org/whl/cu121
30
+ # torch==2.3.1
train.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # train.py
3
+ import os
4
+ import math
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from typing import List, Dict, Any
8
+
9
+ import torch
10
+ from datasets import load_dataset
11
+ from transformers import (
12
+ AutoTokenizer,
13
+ AutoModelForCausalLM,
14
+ BitsAndBytesConfig,
15
+ )
16
+ from peft import LoraConfig, get_peft_model
17
+ from trl import GRPOTrainer, GRPOConfig
18
+
19
+ # ---------------------------
20
+ # 0) Basic config (edit here)
21
+ # ---------------------------
22
+ MODEL_ID = "microsoft/phi-2" # base model
23
+ OUTPUT_DIR = "./runs/phi2-grpo-qlora" # logs + checkpoints
24
+ ADAPTER_DIR = "./adapters" # where LoRA adapters are saved
25
+ HF_DATASET = "OpenAssistant/oasst1" # dataset id
26
+
27
+ # Token lengths: keep well within Phi-2 context (2048)
28
+ # Prompt 1536 + completion 256 leaves headroom for BOS/EOS & formatting
29
+ MAX_PROMPT_LEN = 1536
30
+ MAX_COMPLETION_LEN = 256
31
+
32
+ # GRPO shape parameters (we'll sanity-check these below)
33
+ NUM_GENERATIONS = 4 # completions per prompt
34
+ PER_DEVICE_TRAIN_BS = 4 # "effective per-device" batch used by GRPO
35
+ GENERATION_BATCH_SIZE = 4 # how many sequences we generate at once
36
+
37
+ LEARNING_RATE = 5e-6
38
+ NUM_EPOCHS = 1
39
+ LOG_STEPS = 5
40
+ SAVE_STEPS = 200
41
+
42
+ # ---------------------------
43
+ # 1) Utilities
44
+ # ---------------------------
45
+ def has_gpu() -> bool:
46
+ return torch.cuda.is_available()
47
+
48
+ def suggest_divisible_values(num_processes: int, per_device_bs: int, limit: int = 16) -> List[int]:
49
+ """Suggest valid num_generations values dividing the global train batch size."""
50
+ global_bs = num_processes * per_device_bs
51
+ return [g for g in range(1, limit + 1) if global_bs % g == 0]
52
+
53
+ def ensure_divisibility_or_die(num_processes: int):
54
+ """
55
+ Validate GRPO constraints to avoid:
56
+ - ValueError: global train batch size must be divisible by num_generations
57
+ - ValueError: generation_batch_size must be divisible by num_generations
58
+ """
59
+ global_bs = num_processes * PER_DEVICE_TRAIN_BS
60
+ ok1 = (global_bs % NUM_GENERATIONS == 0)
61
+ ok2 = (GENERATION_BATCH_SIZE % NUM_GENERATIONS == 0)
62
+ if ok1 and ok2:
63
+ return
64
+ msg = []
65
+ if not ok1:
66
+ vals = suggest_divisible_values(num_processes, PER_DEVICE_TRAIN_BS, limit=64)
67
+ msg.append(
68
+ f"- With num_processes={num_processes} and per_device_train_batch_size={PER_DEVICE_TRAIN_BS}, "
69
+ f"the global train batch size is {global_bs}; choose NUM_GENERATIONS ∈ {vals}."
70
+ )
71
+ if not ok2:
72
+ # suggest the next multiple of NUM_GENERATIONS
73
+ next_mult = (GENERATION_BATCH_SIZE // NUM_GENERATIONS + 1) * NUM_GENERATIONS
74
+ msg.append(
75
+ f"- Set GENERATION_BATCH_SIZE to a multiple of NUM_GENERATIONS={NUM_GENERATIONS} "
76
+ f"(e.g., {next_mult})."
77
+ )
78
+ hint = "\n".join(msg)
79
+ raise ValueError(
80
+ "Invalid GRPO batching parameters.\n"
81
+ + hint
82
+ + "\n(Constraint documented in TRL’s GRPOTrainer.)"
83
+ )
84
+
85
+ # ---------------------------
86
+ # 2) Rewards
87
+ # ---------------------------
88
+ def reward_format(completions: List[List[Dict[str, str]]], **kwargs) -> List[float]:
89
+ """
90
+ Reward if the model produces a non-empty assistant message that ends with punctuation.
91
+ `completions` is list of conversations; each completion is a list of messages:
92
+ [{"role": "assistant", "content": "..."}]
93
+ """
94
+ rewards = []
95
+ for completion in completions:
96
+ text = completion[0]["content"].strip() if completion and completion[0].get("content") else ""
97
+ ok = (len(text) > 10) and (text[-1] in ".!?")
98
+ rewards.append(1.0 if ok else 0.0)
99
+ return rewards
100
+
101
+ def reward_length(completions: List[List[Dict[str, str]]], **kwargs) -> List[float]:
102
+ """
103
+ Reward completions whose token length is in a 'goldilocks' range [64, 256].
104
+ If tokenizer is available in kwargs, use it for accurate token counts.
105
+ """
106
+ tok = kwargs.get("tokenizer", None)
107
+ lo, hi = 64, 256
108
+ scores = []
109
+ for completion in completions:
110
+ text = completion[0]["content"] if completion and completion[0].get("content") else ""
111
+ if not text:
112
+ scores.append(0.0)
113
+ continue
114
+ length = len(tok.encode(text)) if tok else len(text.split())
115
+ if lo <= length <= hi:
116
+ scores.append(1.0)
117
+ else:
118
+ # soft ramp: distance from range normalized
119
+ d = 0 if lo <= length <= hi else min(abs(length - lo), abs(length - hi))
120
+ scores.append(max(0.0, 1.0 - d / hi))
121
+ return scores
122
+
123
+ # ---------------------------
124
+ # 3) Load tokenizer & model (4-bit with CPU fallback)
125
+ # ---------------------------
126
+ def load_tokenizer_and_model():
127
+ # Tokenizer
128
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
129
+ if tokenizer.pad_token is None:
130
+ tokenizer.pad_token = tokenizer.eos_token
131
+
132
+ # 4-bit config for QLoRA
133
+ bnb_config = BitsAndBytesConfig(
134
+ load_in_4bit=True,
135
+ bnb_4bit_quant_type="nf4",
136
+ bnb_4bit_compute_dtype=torch.bfloat16 if has_gpu() else torch.float32,
137
+ bnb_4bit_use_double_quant=True,
138
+ )
139
+
140
+ device_map = "auto" if has_gpu() else {"": "cpu"}
141
+ try:
142
+ model = AutoModelForCausalLM.from_pretrained(
143
+ MODEL_ID,
144
+ quantization_config=bnb_config,
145
+ device_map=device_map,
146
+ trust_remote_code=False,
147
+ )
148
+ except Exception as e:
149
+ # CPU fallback without quantization
150
+ print(f"[WARN] 4-bit load failed ({e}); falling back to CPU fp32.")
151
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map={"": "cpu"})
152
+
153
+ # Disable cache for training; enable gradient checkpointing if GPU-limited
154
+ model.config.use_cache = False
155
+ if has_gpu():
156
+ model.gradient_checkpointing_enable()
157
+
158
+ # QLoRA: target q/k/v projections for Phi-2
159
+ peft_config = LoraConfig(
160
+ r=16, lora_alpha=32, lora_dropout=0.05,
161
+ bias="none", task_type="CAUSAL_LM",
162
+ target_modules=["q_proj", "k_proj", "v_proj"], # Phi-2 attention projections
163
+ )
164
+ model = get_peft_model(model, peft_config)
165
+ model.print_trainable_parameters()
166
+
167
+ return tokenizer, model
168
+
169
+ # ---------------------------
170
+ # 4) Load & clean dataset (English-only) and build prompts
171
+ # ---------------------------
172
+ def load_dataset_oasst1(tokenizer):
173
+ ds = load_dataset(HF_DATASET, split="train")
174
+ # Keep only English rows; strip to columns we need to prevent KeyErrors
175
+ ds = ds.filter(lambda x: x.get("lang", None) == "en")
176
+ keep_cols = {"text", "role", "message_id", "parent_id", "message_tree_id", "lang"}
177
+ drop_cols = [c for c in ds.column_names if c not in keep_cols]
178
+ ds = ds.remove_columns(drop_cols)
179
+
180
+ # Build single-turn "chat" prompts for GRPO: list of messages with role/content.
181
+ # We’ll keep only "prompter" -> user prompts.
182
+ prompts = []
183
+ for rec in ds:
184
+ if rec.get("role") == "prompter":
185
+ content = rec.get("text", "").strip()
186
+ if not content:
187
+ continue
188
+ # minimal chat turn
189
+ messages = [{"role": "user", "content": content}]
190
+ # Convert to a single string prompt using a generic chat template (tokenizer may have one)
191
+ if hasattr(tokenizer, "apply_chat_template"):
192
+ prompt_str = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
193
+ else:
194
+ # fallback plain text
195
+ prompt_str = "User: " + content + "\nAssistant:"
196
+
197
+ prompts.append({"prompt": prompt_str})
198
+
199
+ # Keep it as a simple dict dataset for TRL
200
+ return prompts
201
+
202
+ # ---------------------------
203
+ # 5) Main
204
+ # ---------------------------
205
+ def main():
206
+ warnings.filterwarnings("default") # show potential FutureWarnings for max_prompt_length evolution
207
+
208
+ tokenizer, model = load_tokenizer_and_model()
209
+
210
+ # Prepare dataset entries like {"prompt": "..."} as TRL suggests
211
+ train_dataset = load_dataset_oasst1(tokenizer)
212
+
213
+ # Validate GRPO divisibility constraints (avoid the common ValueError)
214
+ num_processes = int(os.environ.get("WORLD_SIZE", "1"))
215
+ ensure_divisibility_or_die(num_processes)
216
+
217
+ # GRPO training args
218
+ args = GRPOConfig(
219
+ output_dir=OUTPUT_DIR,
220
+ num_train_epochs=NUM_EPOCHS,
221
+ learning_rate=LEARNING_RATE,
222
+ logging_steps=LOG_STEPS,
223
+ save_steps=SAVE_STEPS,
224
+ save_total_limit=2,
225
+ bf16=has_gpu(), # prefer bf16 when available
226
+ per_device_train_batch_size=PER_DEVICE_TRAIN_BS,
227
+ generation_batch_size=GENERATION_BATCH_SIZE,
228
+ num_generations=NUM_GENERATIONS,
229
+ max_prompt_length=MAX_PROMPT_LEN, # keep below context limit
230
+ max_completion_length=MAX_COMPLETION_LEN,
231
+ gradient_accumulation_steps=1,
232
+ report_to="none",
233
+ disable_dropout=True, # stabilizes GRPO per TRL notes
234
+ )
235
+
236
+ # Combine our two reward functions (equal weights)
237
+ reward_funcs = [reward_format, reward_length]
238
+
239
+ trainer = GRPOTrainer(
240
+ model=model,
241
+ args=args,
242
+ reward_funcs=reward_funcs,
243
+ train_dataset=train_dataset,
244
+ tokenizer=tokenizer, # passed to reward funcs via kwargs
245
+ )
246
+
247
+ trainer.train()
248
+
249
+ # Save ONLY the adapters (PEFT)
250
+ os.makedirs(ADAPTER_DIR, exist_ok=True)
251
+ trainer.model.save_pretrained(ADAPTER_DIR)
252
+ print(f"[OK] LoRA adapters saved to: {ADAPTER_DIR}")
253
+
254
+ if __name__ == "__main__":
255
+ main()