sunkencity committed
Commit 3e19754 · verified · 1 Parent(s): f33bf5c

Upload train_aviation.py with huggingface_hub

Files changed (1)
  1. train_aviation.py +124 -0
train_aviation.py ADDED
@@ -0,0 +1,124 @@
+ # /// script
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "transformers>=4.36.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ #     "bitsandbytes",
+ #     "scipy",
+ #     "flash-attn"
+ # ]
+ # ///
+
+ import trackio
+ import torch
+ from datasets import load_dataset
+ from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
+ from trl import SFTTrainer, SFTConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ # Model ID
+ model_id = "mistralai/Mistral-3-14B-Reasoning-2512"
+
+ # Load dataset
+ print("📦 Loading dataset...")
+ dataset = load_dataset("sakharamg/AviationQA", split="train")
+
+ # Limit dataset size for reasonable training time (e.g., 10k examples)
+ # 1M rows is too large for a single generic fine-tuning job without massive compute.
+ print("✂️ Subsampling dataset to 10,000 examples for efficiency...")
+ dataset = dataset.shuffle(seed=42).select(range(10000))
+
+ # Map to chat format
+ print("🔄 Mapping dataset...")
+ def to_messages(example):
+     return {
+         "messages": [
+             {"role": "user", "content": example["Question"]},
+             {"role": "assistant", "content": example["Answer"]}
+         ]
+     }
+ dataset = dataset.map(to_messages, remove_columns=dataset.column_names)
+
+ # Split
+ print("🔀 Creating train/eval split...")
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+ train_dataset = dataset_split["train"]
+ eval_dataset = dataset_split["test"]
+
+ # Quantization Config (4-bit for memory efficiency)
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+ )
+
+ # Load Model
+ print(f"🤖 Loading model {model_id}...")
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     quantization_config=bnb_config,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+     attn_implementation="flash_attention_2"
+ )
+ model = prepare_model_for_kbit_training(model)
+
+ # Tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.pad_token = tokenizer.eos_token
+ # Fix for some models that miss chat_template or padding
+ if tokenizer.chat_template is None:
+     tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
+ # LoRA Config
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ )
+
+ # Training Config
+ config = SFTConfig(
+     output_dir="Mistral-3-14B-AviationQA-SFT",
+     push_to_hub=True,
+     hub_model_id="sunkencity/Mistral-3-14B-AviationQA-SFT",
+     hub_strategy="every_save",
+     num_train_epochs=1,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-4,
+     fp16=False,
+     bf16=True,
+     logging_steps=10,
+     save_strategy="steps",
+     save_steps=100,
+     eval_strategy="steps",
+     eval_steps=100,
+     report_to="trackio",
+     project="aviation-qa-tuning",
+     run_name="mistral-14b-sft-v1",
+     max_seq_length=2048,
+     dataset_kwargs={"add_special_tokens": False}  # Let tokenizer handle chat template
+ )
+
+ # Trainer
+ trainer = SFTTrainer(
+     model=model,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     args=config,
+     peft_config=peft_config,
+     processing_class=tokenizer,  # trl>=0.12 uses 'processing_class' in place of the old 'tokenizer' kwarg
+ )
+
+ print("🚀 Starting training...")
+ trainer.train()
+
+ print("💾 Pushing to Hub...")
+ trainer.push_to_hub()
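Because the file carries PEP 723 inline metadata, it can be launched with any runner that understands that header (for example `uv run train_aviation.py`) on a machine with a suitable GPU. After the run, one way to sanity-check the result is to load the pushed adapter back with PEFT and send a single chat-formatted prompt through it. The sketch below is not part of the uploaded script: it assumes the adapter and tokenizer ended up at the configured hub_model_id (sunkencity/Mistral-3-14B-AviationQA-SFT), that the base model fits in memory in bfloat16, and the sample question is purely illustrative.

# Minimal post-training smoke test (assumptions noted above).
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_id = "sunkencity/Mistral-3-14B-AviationQA-SFT"

# Load the tokenizer and the base model with the LoRA adapter applied on top.
tokenizer = AutoTokenizer.from_pretrained(adapter_id)
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Format a single question with the same chat template used during training.
messages = [{"role": "user", "content": "What engines power the Airbus A320?"}]  # illustrative prompt
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Generate and print only the newly produced tokens.
output_ids = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))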