# /// script
# dependencies = [
# "torch",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.46.0",
# "huggingface_hub>=0.26.0",
# "accelerate>=0.24.0",
# "trackio",
# "bitsandbytes",
# "scipy",
# ]
# ///
import trackio
import torch
import os
from huggingface_hub import list_repo_files
# --- Hub access diagnostics --------------------------------------------------
# Fail-fast probe: confirm the HF token is present and the model repo is
# readable before any heavy model/dataset downloads begin.
print("π DIAGNOSTICS:")
token = os.environ.get("HF_TOKEN")
print(f" HF_TOKEN env var present: {bool(token)}")
if token:
    # Only a short prefix so the secret never leaks into job logs.
    print(f" HF_TOKEN prefix: {token[:4]}...")
model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
try:
    print(f" Attempting to list files for {model_id}...")
    files = list_repo_files(model_id, token=token)
    # NOTE(review): this message was split across two lines in the paste
    # (unterminated string literal); rejoined here with a single space.
    print(f" β Success! Found {len(files)} files.")
    print(f" First 5 files: {files[:5]}")
except Exception as e:
    # Best-effort probe: report and continue — from_pretrained below will
    # raise a clearer error if access is genuinely broken.
    print(f" β Failed to list repo files: {e}")
print("="*40)
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
# Register 'ministral3' config to handle nested text_config.
# Strategy: monkey-patch the native MinistralConfig's model_type so a
# checkpoint whose config.json says "ministral3" resolves to a class that is
# already wired into the AutoModel machinery.
print("π§ Registering ministral3 config (Monkey Patch Strategy)...")
try:
    # AutoConfig is already imported at module top; only MinistralConfig is new.
    from transformers import MinistralConfig

    print(f" Original MinistralConfig.model_type: {MinistralConfig.model_type}")
    # Monkey patch the model_type to match what the config.json has.
    MinistralConfig.model_type = "ministral3"
    print(f" Patched MinistralConfig.model_type: {MinistralConfig.model_type}")
    # Register the patched class for the "ministral3" key.
    AutoConfig.register("ministral3", MinistralConfig)
    print(" Registered ministral3 -> MinistralConfig (native, patched)")
except Exception as e:
    # Non-fatal: model loading may still succeed if transformers already
    # knows this architecture natively.
    print(f" β Failed to patch/register ministral3 config: {e}")
# Map Mistral3Config to a concrete model class so AutoModelForCausalLM can
# instantiate it; fall back to the text-only MistralForCausalLM when the
# multimodal class is unavailable in this transformers build.
print("π§ Registering Mistral3 model class...")
try:
    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    try:
        from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
        AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
        print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
    except ImportError:
        print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
        from transformers import MistralForCausalLM
        AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
        print(" Registered Mistral3Config -> MistralForCausalLM")
except ImportError as e:
    print(f" β Failed to find Mistral3Config or register model: {e}")
# Model ID
# model_id defined above
# Pull the AviationQA corpus and cut it to a tractable size: the full split
# (~1M rows) is far too large for a single generic fine-tuning job, so we
# take a deterministic 10k random sample.
print("π¦ Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")
print("βοΈ Subsampling dataset to 10,000 examples for efficiency...")
shuffled = dataset.shuffle(seed=42)
dataset = shuffled.select(range(10000))
print("π Mapping dataset...")
def to_messages(example):
    """Convert a raw AviationQA row into the chat ``messages`` format TRL expects.

    Args:
        example: Mapping with ``"Question"`` and ``"Answer"`` string fields.

    Returns:
        Dict with a single ``"messages"`` key: the user question followed by
        the assistant answer.
    """
    return {
        "messages": [
            {"role": "user", "content": example["Question"]},
            {"role": "assistant", "content": example["Answer"]},
        ]
    }
# Replace the raw Question/Answer columns with chat-formatted messages,
# then carve out a deterministic 10% evaluation split.
dataset = dataset.map(to_messages, remove_columns=dataset.column_names)
print("π Creating train/eval split...")
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = splits["train"], splits["test"]
# 4-bit NF4 weight quantization with double quantization, computing in
# bfloat16 — keeps the 14B model within single-job memory budgets.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the base model in 4-bit across available devices, then prepare it for
# k-bit (QLoRA) training.
print(f"π€ Loading model {model_id}...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # Default attention for compatibility
    attn_implementation="eager",
)
model = prepare_model_for_kbit_training(model)
# Tokenizer setup: reuse EOS as the pad token, and install a ChatML-style
# template for checkpoints that ship without one.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )
# LoRA adapter: rank-16 adapters on every attention and MLP projection.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
# Training configuration.
config = SFTConfig(
    # Output / Hub publishing
    output_dir="Mistral-3-14B-AviationQA-SFT",
    push_to_hub=True,
    hub_model_id="sunkencity/Mistral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",
    # Optimization schedule (effective batch = 4 * 4 = 16)
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # Precision: bf16 matches the 4-bit compute dtype above
    fp16=False,
    bf16=True,
    # Logging / checkpointing / evaluation cadence
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    # Experiment tracking via trackio
    report_to="trackio",
    project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
    # Sequence handling
    max_length=2048,
    dataset_kwargs={"add_special_tokens": False},  # Let tokenizer handle chat template
)
# Assemble and run the trainer. With the pinned trl>=0.12.0 the tokenizer is
# passed as `processing_class` — the old `tokenizer=` kwarg is deprecated
# there and removed in later TRL releases.
trainer = SFTTrainer(
    model=model,
    args=config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
)
print("π Starting training...")
trainer.train()
print("πΎ Pushing to Hub...")
trainer.push_to_hub()