# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch>=2.0.0", "transformers>=4.40.0", "accelerate>=0.20.0", "bitsandbytes>=0.41.0", "protobuf>=3.20.0", "sentencepiece>=0.1.99"]
# ///

"""Fine-tune Mistral-7B-Instruct-v0.3 on NATO doctrine dataset."""

import os
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
import torch
import trackio  # noqa: F401 -- used indirectly via report_to="trackio" below

# Authenticate with HF Hub using token from secrets
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to Hugging Face Hub")
else:
    print("⚠ Warning: HF_TOKEN not found in environment")

# Model ID
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
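# Mistral-7B ships without a dedicated pad token, so EOS is reused for padding;
# right-side padding is the usual choice for SFT (left padding is only needed at generation time).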
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with 4-bit quantization
print("Loading model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
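# The KV cache is not needed during training and conflicts with gradient
# checkpointing, so disable it; checkpointing trades extra compute for memory.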
model.config.use_cache = False
model.gradient_checkpointing_enable()

print(f"✓ Model loaded: {model_id}")

# Load dataset from HF Hub
print("\nLoading NATO doctrine dataset...")
dataset = load_dataset("AndreasThinks/nato-doctrine-sft", split="train")
dataset_test = load_dataset("AndreasThinks/nato-doctrine-sft", split="test")

print(f"✓ Train set: {len(dataset)} examples")
print(f"✓ Test set: {len(dataset_test)} examples")

# Configure LoRA for efficient fine-tuning
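# Rank-16 adapters (alpha 32) on every attention and MLP projection keep the
# trainable parameter count to a small fraction of the full 7B model.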
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# Training configuration
training_args = SFTConfig(
    output_dir="mistral-7b-nato-doctrine",

    # Model saving
    push_to_hub=True,
    hub_model_id="AndreasThinks/mistral-7b-nato-doctrine",
    hub_strategy="every_save",
    hub_private_repo=False,

    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 16
    gradient_checkpointing=True,

    # Learning rate
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # Optimization
    optim="adamw_torch",
    weight_decay=0.01,
    max_grad_norm=1.0,

    # Evaluation
    eval_strategy="steps",
    eval_steps=50,

    # Logging and saving
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,

    # Monitoring with Trackio
    report_to="trackio",
    run_name="nato-mistral-7b-v1",
    project="nato-doctrine-training",

    # Other
    bf16=True,  # Use bfloat16 for better stability
    seed=42,
)

# Initialize trainer with loaded model
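# TRL >= 0.12 takes the tokenizer via `processing_class`; SFTTrainer uses it to
# tokenize the dataset internally.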
print("\n✓ Initializing SFT trainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    eval_dataset=dataset_test,
    peft_config=peft_config,
    args=training_args,
)

# Start training
print("\n✓ Starting training...")
print(f"  Model: mistralai/Mistral-7B-Instruct-v0.3")
print(f"  Training examples: {len(dataset)}")
print(f"  Test examples: {len(dataset_test)}")
print(f"  Epochs: 3")
print(f"  LoRA rank: 16")
print(f"  Output: AndreasThinks/mistral-7b-nato-doctrine\n")

trainer.train()

# Save final model
print("\n✓ Training complete! Saving final model...")
trainer.push_to_hub()

print("\n✅ Fine-tuning complete!")
print(f"  Model: https://huggingface.co/AndreasThinks/mistral-7b-nato-doctrine")
print(f"  Base: mistralai/Mistral-7B-Instruct-v0.3")
print(f"  Trackio: Check your dashboard for metrics")