infrastructure-training-scripts / train_infrastructure_model.py
lokegud's picture
Upload train_infrastructure_model.py with huggingface_hub
3365f48 verified
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.38.0",
# "datasets>=2.16.0",
# "torch>=2.1.0",
# "accelerate>=0.26.0",
# "bitsandbytes>=0.42.0",
# "trackio>=0.3.0"
# ]
# ///
"""
Infrastructure Security Training - SFT Fine-tuning
Trains Qwen 2.5 7B on infrastructure management tasks
"""
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch
import trackio
# Model and dataset configuration
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
DATASET_NAME = "lokegud/infrastructure-security-training"
OUTPUT_MODEL = "lokegud/infrastructure-assistant-7b"
print("=" * 60)
print("Infrastructure Assistant Training")
print("=" * 60)
print(f"Base Model: {BASE_MODEL}")
print(f"Dataset: {DATASET_NAME}")
print(f"Output: {OUTPUT_MODEL}")
print("=" * 60)
# Load dataset
print("\nLoading dataset...")
dataset = load_dataset(DATASET_NAME)
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]
print(f"Train examples: {len(train_dataset):,}")
print(f"Eval examples: {len(eval_dataset):,}")
# Format dataset for instruction tuning
def format_instruction(example):
"""Format examples as instruction-following prompts"""
instruction = example["instruction"]
input_text = example.get("input", "")
output = example["output"]
if input_text:
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
else:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
return {"text": prompt}
print("\nFormatting dataset...")
train_dataset = train_dataset.map(format_instruction, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_instruction, remove_columns=eval_dataset.column_names)
# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# QLoRA configuration for efficient training
print("Configuring QLoRA...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
# Load model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
# LoRA configuration
print("Configuring LoRA adapters...")
peft_config = LoraConfig(
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# Training configuration
print("Configuring training...")
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=OUTPUT_MODEL,
# Training parameters
num_train_epochs=3,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8,
gradient_checkpointing=True,
# Optimization
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
weight_decay=0.01,
optim="paged_adamw_8bit",
# Evaluation and logging
eval_strategy="steps",
eval_steps=100,
logging_steps=10,
save_strategy="steps",
save_steps=200,
save_total_limit=3,
# Hub integration
push_to_hub=True,
hub_model_id=OUTPUT_MODEL,
hub_strategy="every_save",
hub_private_repo=False,
# Tracking
report_to="trackio",
run_name="infrastructure-assistant-qwen-7b",
# Performance
bf16=True,
max_grad_norm=0.3,
# Misc
seed=42,
)
# Initialize trainer
print("Initializing trainer...")
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
args=training_args,
)
# Train
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
trainer.train()
# Save final model
print("\nSaving final model...")
trainer.save_model()
# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub()
print("\n" + "=" * 60)
print("Training complete!")
print("=" * 60)
print(f"Model saved to: https://huggingface.co/{OUTPUT_MODEL}")
print("\nNext steps:")
print(" 1. Test the model on HuggingFace Hub")
print(" 2. Convert to GGUF for Ollama deployment")
print(" 3. Deploy to your infrastructure")