kingjux committed on
Commit
4dbdabc
·
verified ·
1 Parent(s): 8de7cf8

Upload train_ffmpeg.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_ffmpeg.py +74 -0
train_ffmpeg.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers", "datasets", "accelerate", "bitsandbytes"]
# ///
"""LoRA supervised fine-tuning of Qwen/Qwen2.5-0.5B-Instruct on ffmpeg commands.

The script loads the ``kingjux/ffmpeg-commands-cot`` training split, attaches
LoRA adapters to the attention and MLP projection layers, trains with TRL's
``SFTTrainer`` and pushes the result to ``kingjux/ffmpeg-command-generator``
on the Hugging Face Hub.
"""

import inspect

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio  # never referenced below; reporting is selected via report_to="trackio"

# Load the dataset
dataset = load_dataset("kingjux/ffmpeg-commands-cot", split="train")
print(f"Loaded {len(dataset)} training examples")

# LoRA config for efficient fine-tuning
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# The dependency spec above only sets a lower bound on trl. Newer TRL releases
# call the sequence-length option `max_length`; older ones (including the
# declared minimum) call it `max_seq_length`, and that old name was eventually
# removed. Pick whichever keyword the installed SFTConfig accepts so the script
# does not fail with a TypeError before training starts.
_sft_config_params = inspect.signature(SFTConfig.__init__).parameters
_seq_len_kwarg = "max_length" if "max_length" in _sft_config_params else "max_seq_length"

# Training config
training_args = SFTConfig(
    output_dir="ffmpeg-command-generator",

    # Training params
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_ratio=0.1,

    # Logging and saving
    logging_steps=5,
    save_strategy="epoch",

    # Hub settings
    push_to_hub=True,
    hub_model_id="kingjux/ffmpeg-command-generator",
    hub_strategy="every_save",

    # Trackio monitoring
    report_to="trackio",
    run_name="ffmpeg-sft-30examples",

    # Memory optimization
    gradient_checkpointing=True,
    bf16=True,

    # Other
    seed=42,
    **{_seq_len_kwarg: 1024},
)

# Create trainer
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    peft_config=peft_config,
    args=training_args,
)

# Train
print("Starting training...")
trainer.train()

# Save and push
print("Pushing to Hub...")
trainer.save_model()
trainer.push_to_hub()

print("Training complete!")