# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch>=2.0.0", "transformers>=4.40.0", "accelerate>=0.20.0", "bitsandbytes>=0.41.0", "protobuf>=3.20.0", "sentencepiece>=0.1.99"]
# ///
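# Inline script metadata (PEP 723): with uv installed, this can be run as
# `uv run train_mistral.py`, which resolves the dependencies above automatically.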
"""Fine-tune Mistral-7B-Instruct-v0.3 on NATO doctrine dataset."""
import os
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
import torch
import trackio  # noqa: F401 -- not called directly; logging goes through report_to="trackio" below
# Authenticate with HF Hub using token from secrets
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
print("✓ Logged in to Hugging Face Hub")
else:
print("âš  Warning: HF_TOKEN not found in environment")
# Model ID
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
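# Mistral's tokenizer ships without a pad token, so reuse EOS for padding;
# right-side padding is the usual convention for causal-LM fine-tuning.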
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Load model with 4-bit quantization
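# NF4 ("normal float 4") with nested/double quantization keeps the 7B base
# weights to around 4 GB of VRAM, while compute runs in bfloat16.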
print("Loading model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
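# The KV cache is incompatible with gradient checkpointing, so disable it for
# training; checkpointing trades extra recompute for lower activation memory.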
model.config.use_cache = False
model.gradient_checkpointing_enable()
print(f"✓ Model loaded: {model_id}")
# Load dataset from HF Hub
print("\nLoading NATO doctrine dataset...")
dataset = load_dataset("AndreasThinks/nato-doctrine-sft", split="train")
dataset_test = load_dataset("AndreasThinks/nato-doctrine-sft", split="test")
print(f"✓ Train set: {len(dataset)} examples")
print(f"✓ Test set: {len(dataset_test)} examples")
# Configure LoRA for efficient fine-tuning
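# Rank-16 adapters with alpha=32 (an effective 2x scaling) on every attention
# and MLP projection; only these adapter weights train, not the 4-bit base.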
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
# Training configuration
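# SFTConfig subclasses transformers.TrainingArguments, so the standard Trainer
# knobs (optimizer, scheduler, eval/save cadence) are configured here as well.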
training_args = SFTConfig(
output_dir="nato-ministral-3b",
# Model saving
push_to_hub=True,
hub_model_id="AndreasThinks/mistral-7b-nato-doctrine",
hub_strategy="every_save",
hub_private_repo=False,
# Training parameters
num_train_epochs=3,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8, # Effective batch size = 16
gradient_checkpointing=True,
# Learning rate
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
# Optimization
optim="adamw_torch",
weight_decay=0.01,
max_grad_norm=1.0,
# Evaluation
eval_strategy="steps",
eval_steps=50,
# Logging and saving
logging_steps=10,
save_strategy="steps",
save_steps=100,
save_total_limit=3,
# Monitoring with Trackio
report_to="trackio",
run_name="nato-mistral-7b-v1",
project="nato-doctrine-training",
# Other
bf16=True, # Use bfloat16 for better stability
seed=42,
)
# Initialize trainer with loaded model
print("\n✓ Initializing SFT trainer...")
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
train_dataset=dataset,
eval_dataset=dataset_test,
peft_config=peft_config,
args=training_args,
)
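# With peft_config supplied, SFTTrainer wraps the quantized model with the LoRA
# adapters itself (including k-bit training preparation), so only the adapter
# weights receive gradients.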
# Start training
print("\n✓ Starting training...")
print(f" Model: mistralai/Mistral-7B-Instruct-v0.3")
print(f" Training examples: {len(dataset)}")
print(f" Test examples: {len(dataset_test)}")
print(f" Epochs: 3")
print(f" LoRA rank: 16")
print(f" Output: AndreasThinks/mistral-7b-nato-doctrine\n")
trainer.train()
# Save final model
print("\n✓ Training complete! Saving final model...")
trainer.push_to_hub()
print("\n✅ Fine-tuning complete!")
print(f" Model: https://huggingface.co/AndreasThinks/mistral-7b-nato-doctrine")
print(f" Base: mistralai/Mistral-7B-Instruct-v0.3")
print(f" Trackio: Check your dashboard for metrics")