ceperaltab committed on
Commit
65ed589
·
verified ·
1 Parent(s): 7ba7e95

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +142 -0
train.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.7.0",
7
+ # "transformers>=4.44.2",
8
+ # "accelerate>=0.24.0",
9
+ # "bitsandbytes>=0.41.0",
10
+ # "datasets",
11
+ # "scipy",
12
+ # "hf_transfer",
13
+ # "rich",
14
+ # "trackio",
15
+ # ]
16
+ # ///
17
+
18
+ import torch
19
+ import os
20
+ import trackio
21
+ from datasets import load_dataset
22
+ from peft import LoraConfig
23
+ from trl import SFTTrainer, SFTConfig
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
25
+
26
# === CONFIGURATION ===
# Hub id of the base model that main() loads in 4-bit and adapts with LoRA.
MODEL_NAME = "Qwen/Qwen3-8B"  # Base model — fits locally on M2 Pro 16GB after fine-tuning
# Hub id of the dataset; main() loads its "train" split.
DATASET_NAME = "ceperaltab/diamond-vision-dataset"
# Local output directory; also used as the Hub repository name.
OUTPUT_DIR = "diamond-vision-expert"
# Hub namespace; combined with OUTPUT_DIR to build hub_model_id and the final URL.
HF_USERNAME = "ceperaltab"
31
+
32
def main():
    """Fine-tune ``MODEL_NAME`` on ``DATASET_NAME`` with QLoRA and push the adapter.

    Steps, in order:

    1. Load the dataset's ``train`` split and hold out 5% for evaluation.
    2. Build the ``SFTConfig`` (Hub upload, schedule, logging to trackio).
    3. Load the base model in 4-bit NF4 and the matching tokenizer.
    4. Train a LoRA adapter with ``SFTTrainer`` and push it to the Hub.

    The function takes no arguments and returns ``None``; all settings come
    from the module-level constants. It needs network access and Hub
    credentials with write permission because ``push_to_hub=True``.
    """
    # Local import: only needed for the SFTConfig compatibility probe below.
    import dataclasses

    print("=" * 60)
    print("Diamond Vision Expert — QLoRA Fine-tuning")
    print(f"Base model : {MODEL_NAME}")
    print(f"Dataset : {DATASET_NAME}")
    print("=" * 60)

    # Load dataset
    print(f"📦 Loading dataset: {DATASET_NAME}...")
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"✅ Dataset loaded: {len(dataset)} examples")

    # Train / eval split (fixed seed so the held-out set is reproducible)
    print("🔀 Creating train/eval split...")
    dataset_split = dataset.train_test_split(test_size=0.05, seed=42)
    train_dataset = dataset_split["train"]
    eval_dataset = dataset_split["test"]
    print(f" Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")

    # The script's dependency range is open-ended, so probe which optional
    # SFTConfig fields the installed versions actually accept instead of
    # failing with a TypeError on an unknown keyword.
    sft_fields = {f.name for f in dataclasses.fields(SFTConfig)}
    version_dependent = {}

    # Newer TRL names the sequence-length cap ``max_length``; older releases
    # call it ``max_seq_length``.
    length_key = "max_length" if "max_length" in sft_fields else "max_seq_length"
    version_dependent[length_key] = 2048  # CV code is verbose — larger than default

    if "project" in sft_fields:
        version_dependent["project"] = "diamond-vision-training"
    else:
        # NOTE(review): presumably older trackio integrations read the project
        # name from this environment variable — confirm against the installed
        # transformers version.
        os.environ.setdefault("TRACKIO_PROJECT", "diamond-vision-training")

    # Training config
    config = SFTConfig(
        output_dir=OUTPUT_DIR,
        push_to_hub=True,
        hub_model_id=f"{HF_USERNAME}/{OUTPUT_DIR}",
        hub_strategy="every_save",

        # Training
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,

        # Logging & checkpointing
        logging_steps=10,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,

        # Evaluation
        eval_strategy="steps",
        eval_steps=500,

        # Optimization
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,
        bf16=True,  # A10G supports bf16

        # Monitoring
        report_to="trackio",
        run_name="diamond-vision-qwen3-8b-v1",

        **version_dependent,
    )

    # LoRA
    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )

    # 4-bit QLoRA quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Load model
    # NOTE(review): trust_remote_code=True lets the Hub repository run its own
    # Python code at load time; confirm it is still required for this model.
    print(f"🔄 Loading base model: {MODEL_NAME}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Only fall back to EOS when the tokenizer ships without a pad token.
    # Overwriting an existing pad token with EOS makes collators that mask
    # labels by pad id also mask every EOS, so the end-of-sequence token would
    # get no training signal.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Train
    print("🎯 Initializing trainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
    )

    print("🚀 Starting training...")
    trainer.train()

    print("💾 Pushing final adapter to Hub...")
    trainer.push_to_hub()

    trackio.finish()
    print("✅ Done! Adapter pushed to:", f"https://huggingface.co/{HF_USERNAME}/{OUTPUT_DIR}")
140
+
141
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()