unesco-data-ai commited on
Commit
b133a83
·
verified ·
1 Parent(s): 1424572

Upload train_unesco_tagger.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_unesco_tagger.py +69 -0
train_unesco_tagger.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # ]
9
+ # ///
10
+
11
+ from datasets import load_dataset
12
+ from peft import LoraConfig
13
+ from trl import SFTTrainer, SFTConfig
14
+
15
+ print("Loading dataset...")
16
+ dataset = load_dataset("unesco-data-ai/unesco-thesaurus-sft")
17
+ train_dataset = dataset["train"]
18
+ eval_dataset = dataset["validation"]
19
+
20
+ print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
21
+
22
+ config = SFTConfig(
23
+ output_dir="qwen2.5-3b-unesco-tagger",
24
+ push_to_hub=True,
25
+ hub_model_id="unesco-data-ai/qwen2.5-3b-unesco-tagger-v1",
26
+ hub_strategy="every_save",
27
+ num_train_epochs=3,
28
+ per_device_train_batch_size=2,
29
+ gradient_accumulation_steps=8,
30
+ learning_rate=2e-5,
31
+ max_length=2048,
32
+ logging_steps=10,
33
+ save_strategy="steps",
34
+ save_steps=200,
35
+ save_total_limit=2,
36
+ eval_strategy="steps",
37
+ eval_steps=200,
38
+ warmup_ratio=0.1,
39
+ lr_scheduler_type="cosine",
40
+ report_to="trackio",
41
+ project="unesco-keyword-extraction",
42
+ run_name="qwen2.5-3b-sft-v1",
43
+ )
44
+
45
+ peft_config = LoraConfig(
46
+ r=16,
47
+ lora_alpha=32,
48
+ lora_dropout=0.05,
49
+ bias="none",
50
+ task_type="CAUSAL_LM",
51
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
52
+ )
53
+
54
+ print("Initializing trainer...")
55
+ trainer = SFTTrainer(
56
+ model="Qwen/Qwen2.5-3B-Instruct",
57
+ train_dataset=train_dataset,
58
+ eval_dataset=eval_dataset,
59
+ args=config,
60
+ peft_config=peft_config,
61
+ )
62
+
63
+ print("Starting training...")
64
+ trainer.train()
65
+
66
+ print("Pushing to Hub...")
67
+ trainer.push_to_hub()
68
+
69
+ print("Complete!")