Jabuszko commited on
Commit
730b6df
·
verified ·
1 Parent(s): eca7f4c

Create train_dpo.py

Browse files
Files changed (1) hide show
  1. train_dpo.py +48 -0
train_dpo.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
3
+ from trl import DPOTrainer
4
+
5
+ # --------------------
6
+ # 1. USTAW MODEL
7
+ # --------------------
8
+ model_name = "mistralai/Mistral-7B-Instruct-v0.2" # możesz zmienić
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ model_name,
13
+ load_in_4bit=True,
14
+ device_map="auto"
15
+ )
16
+
17
+ # --------------------
18
+ # 2. WCZYTAJ DANE
19
+ # --------------------
20
+ dataset = load_dataset("json", data_files="dpo_data.jsonl")["train"]
21
+
22
+ # --------------------
23
+ # 3. ARGUMENTY TRENINGU
24
+ # --------------------
25
+ training_args = TrainingArguments(
26
+ output_dir="./dpo_output",
27
+ per_device_train_batch_size=2,
28
+ gradient_accumulation_steps=4,
29
+ num_train_epochs=2,
30
+ learning_rate=5e-6,
31
+ bf16=True,
32
+ logging_steps=10,
33
+ save_steps=500,
34
+ )
35
+
36
+ # --------------------
37
+ # 4. START DPO
38
+ # --------------------
39
+ trainer = DPOTrainer(
40
+ model=model,
41
+ ref_model=None,
42
+ tokenizer=tokenizer,
43
+ train_dataset=dataset,
44
+ beta=0.1,
45
+ args=training_args,
46
+ )
47
+
48
+ trainer.train()