lokegud committed
Commit e079765 · verified · Parent: a40b72b

Upload train_infrastructure_model.py with huggingface_hub

Files changed (1)
  1. train_infrastructure_model.py +186 -0
train_infrastructure_model.py ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.46.0",
#     "datasets>=2.16.0",
#     "torch>=2.1.0",
#     "accelerate>=0.26.0",
#     "bitsandbytes>=0.42.0",
#     "trackio>=0.3.0",
# ]
# ///
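# The PEP 723 metadata block above makes the script self-contained; assuming a
# runner with inline-metadata support (e.g. a recent uv), it can be launched as:
#   uv run train_infrastructure_model.py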

"""
Infrastructure Security Training - SFT Fine-tuning
Fine-tunes Qwen2.5-7B-Instruct with QLoRA on infrastructure management tasks.
"""

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch
import trackio  # not called directly; imported so a missing tracking backend fails fast

# Model and dataset configuration
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
DATASET_NAME = "lokegud/infrastructure-security-training"
OUTPUT_MODEL = "lokegud/infrastructure-assistant-7b"

print("=" * 60)
print("Infrastructure Assistant Training")
print("=" * 60)
print(f"Base Model: {BASE_MODEL}")
print(f"Dataset: {DATASET_NAME}")
print(f"Output: {OUTPUT_MODEL}")
print("=" * 60)

# Load dataset
print("\nLoading dataset...")
dataset = load_dataset(DATASET_NAME)
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

print(f"Train examples: {len(train_dataset):,}")
print(f"Eval examples: {len(eval_dataset):,}")

# Format dataset for instruction tuning
def format_instruction(example):
    """Format examples as Alpaca-style instruction-following prompts."""
    instruction = example["instruction"]
    input_text = example.get("input", "")
    output = example["output"]

    if input_text:
        prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
    else:
        prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"

    return {"text": prompt}

print("\nFormatting dataset...")
train_dataset = train_dataset.map(format_instruction, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_instruction, remove_columns=eval_dataset.column_names)
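# Illustrative only (field values below are hypothetical, not from the dataset):
# a record {"instruction": "Restart the nginx service", "input": "",
# "output": "sudo systemctl restart nginx"} collapses to a single "text" field:
#   ### Instruction:
#   Restart the nginx service
#
#   ### Response:
#   sudo systemctl restart nginx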

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# QLoRA configuration for efficient training
print("Configuring QLoRA...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
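# Rough sizing assumption, not a measurement: NF4 4-bit weights for a 7B model
# are on the order of 4 GB, so base weights plus LoRA adapters, paged 8-bit
# optimizer state, and activations typically fit on a single 24 GB GPU at this
# batch size and sequence length.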

# Load model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing during training
model.config.pretraining_tp = 1  # Llama-recipe holdover; effectively a no-op for Qwen

# LoRA configuration
print("Configuring LoRA adapters...")
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
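# LoRA updates are scaled by lora_alpha / r, so alpha=16 with r=64 applies a
# conservative 0.25 factor; targeting every attention and MLP projection is
# the usual QLoRA recipe for Qwen/Llama-style transformer blocks.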

# Training configuration
print("Configuring training...")
training_args = SFTConfig(
    output_dir=OUTPUT_MODEL,

    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,

    # Dataset handling (SFTConfig fields in trl >= 0.12, not SFTTrainer kwargs)
    dataset_text_field="text",
    max_seq_length=2048,
    packing=False,

    # Optimization
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    optim="paged_adamw_8bit",

    # Evaluation and logging
    eval_strategy="steps",
    eval_steps=50,
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Hub integration
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL,
    hub_strategy="every_save",
    hub_private_repo=False,

    # Tracking
    report_to="trackio",
    run_name="infrastructure-assistant-qwen-7b",

    # Performance
    bf16=True,
    max_grad_norm=0.3,
    group_by_length=True,

    # Misc
    seed=42,
)

# Initialize trainer
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,  # `tokenizer=` was deprecated in trl 0.12
)

# Train
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)

trainer.train()

# Save final model
print("\nSaving final model...")
trainer.save_model()

# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub()

print("\n" + "=" * 60)
print("Training complete!")
print("=" * 60)
print(f"Model saved to: https://huggingface.co/{OUTPUT_MODEL}")
print("\nNext steps:")
print(" 1. Test the model on HuggingFace Hub")
print(" 2. Convert to GGUF for Ollama deployment")
print(" 3. Deploy to your infrastructure")
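
# A minimal smoke test of the pushed adapter, sketched as comments (assumes the
# Hub upload succeeded; the prompt text is a hypothetical example):
#
#   from transformers import AutoTokenizer
#   from peft import AutoPeftModelForCausalLM
#
#   tok = AutoTokenizer.from_pretrained(OUTPUT_MODEL)
#   mdl = AutoPeftModelForCausalLM.from_pretrained(OUTPUT_MODEL, device_map="auto")
#   inputs = tok("### Instruction:\nCheck disk usage on a host.\n\n### Response:\n",
#                return_tensors="pt").to(mdl.device)
#   out = mdl.generate(**inputs, max_new_tokens=128)
#   print(tok.decode(out[0], skip_special_tokens=True))
#
# For the GGUF/Ollama step, merge the adapter into the base weights first
# (mdl.merge_and_unload()), save the merged model, then convert the saved
# directory with llama.cpp's convert_hf_to_gguf.py.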