Caesaropapism committed on
Commit
47c9c47
·
verified ·
1 Parent(s): 0e7bb4a

Upload run_lora_training.py with huggingface_hub

Files changed (1)
  1. run_lora_training.py +211 -0
run_lora_training.py ADDED
#!/usr/bin/env python3
"""
QLoRA fine-tuning entry point for GraiLLM.

Designed for use on Google Colab, Kaggle, or Hugging Face free GPUs.
The script expects the dataset generated by `prepare_dataset.py` where each
record contains a chat-style `messages` list.
"""
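
# A single JSONL record is expected to look roughly like this sketch
# (values elided; the exact fields come from prepare_dataset.py):
#
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."}]}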

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)


DEFAULT_BASE_MODEL = "openai/gpt-oss-mini-7b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fine-tune GraiLLM with QLoRA.")
    parser.add_argument(
        "--train-file",
        type=Path,
        required=True,
        help="Path to the JSONL training file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--eval-file",
        type=Path,
        required=True,
        help="Path to the JSONL evaluation file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=DEFAULT_BASE_MODEL,
        help="Base Hugging Face model ID to fine-tune (QLoRA friendly).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/graillm-lora"),
        help="Directory where checkpoints and final adapters will be saved.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Effective batch size; divided by --grad-accum to get the per-device micro batch size.",
    )
    parser.add_argument(
        "--grad-accum",
        type=int,
        default=4,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate.",
    )
    parser.add_argument("--max-steps", type=int, default=-1, help="Max training steps (-1 disables the cap).")
    parser.add_argument("--bf16", action="store_true", help="Enable bfloat16 training.")
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the adapter weights to the active Hugging Face repo after training.",
    )
    parser.add_argument(
        "--hub-model-id",
        type=str,
        default="dakotarainlock/GraiLLM-7B-Lora",
        help="Target repository when --push-to-hub is supplied.",
    )
    return parser.parse_args()
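
# Example invocation (paths and optional flags are illustrative, not prescriptive):
#
#   python run_lora_training.py \
#       --train-file data/train.jsonl \
#       --eval-file data/eval.jsonl \
#       --bf16 --push-to-hub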


def format_messages(messages: List[Dict[str, str]]) -> str:
    """Convert a message list into a single training string."""
    turns = []
    for message in messages:
        role = message["role"]
        content = message["content"].strip()
        if not content:
            continue
        if role == "system":
            turns.append(f"<<SYS>>\n{content}\n<</SYS>>")
        elif role == "user":
            turns.append(f"[USER]\n{content}")
        elif role == "assistant":
            turns.append(f"[ASSISTANT]\n{content}")
    return "\n\n".join(turns) + "\n"
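
# For a system + user + assistant record, format_messages yields a string
# shaped like the following (contents elided):
#
#   <<SYS>>
#   ...
#   <</SYS>>
#
#   [USER]
#   ...
#
#   [ASSISTANT]
#   ...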


def tokenize_batch(example: Dict[str, List[Dict[str, str]]], tokenizer: AutoTokenizer):
    text = format_messages(example["messages"])
    tokenized = tokenizer(
        text,
        truncation=True,
        max_length=min(tokenizer.model_max_length, 2048),
        padding=False,
    )
    # Labels are intentionally not set here: DataCollatorForLanguageModeling
    # with mlm=False builds them from input_ids after padding (pad tokens are
    # masked to -100); pre-setting ragged labels would break batch collation.
    return tokenized


def main() -> None:
    args = parse_args()
    torch_dtype = torch.bfloat16 if args.bf16 else torch.float16

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the base model in 4-bit via bitsandbytes; the explicit
    # BitsAndBytesConfig replaces the deprecated bare load_in_4bit kwarg.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        torch_dtype=torch_dtype,
        quantization_config=quant_config,
    )

    model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
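
    # Optional sanity check (a PeftModel method): report how many parameters
    # the adapters actually train, typically a small fraction of the base model.
    model.print_trainable_parameters()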

    dataset = load_dataset(
        "json",
        data_files={
            "train": str(args.train_file),
            "eval": str(args.eval_file),
        },
    )

    tokenized_ds = dataset.map(
        lambda ex: tokenize_batch(ex, tokenizer),
        remove_columns=dataset["train"].column_names,
    )

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=max(1, args.batch_size // args.grad_accum),
        per_device_eval_batch_size=max(1, args.batch_size // args.grad_accum),
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        fp16=not args.bf16,
        bf16=args.bf16,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        report_to="tensorboard",
        max_steps=args.max_steps,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
    )
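
    # With the defaults (--batch-size 16, --grad-accum 4) the math above gives
    # micro-batches of 16 // 4 = 4 per device and an effective batch of
    # 4 * 4 = 16 examples per optimizer step.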

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["eval"],
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    main()