File size: 5,770 Bytes
3e19754 9b26fb6 3e19754 7c82449 9b26fb6 3e19754 18a3f9a 3b8ec8c 18a3f9a 3e19754 3b8ec8c 037cd7b 3b8ec8c 5c34124 3b8ec8c 20566c4 3b8ec8c 18a3f9a 33b1a64 20566c4 3b8ec8c 20566c4 3b8ec8c 20566c4 3b8ec8c 20566c4 037cd7b 20566c4 3b8ec8c 037cd7b 20566c4 33b1a64 f2890b3 20566c4 f2890b3 3b8ec8c 3e19754 3b8ec8c 6f6fc96 3e19754 3b8ec8c 3e19754 3b8ec8c 3e19754 3b8ec8c 3e19754 3b8ec8c 3e19754 bfbfcaf 3b8ec8c 3e19754 037cd7b 3e19754 afbbcb3 3e19754 20566c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# /// script
# dependencies = [
# "torch",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers", # Let UV pick latest
# "huggingface_hub>=0.26.0",
# "accelerate>=0.24.0",
# "trackio",
# "bitsandbytes",
# "scipy",
# ]
# ///
import trackio
import torch
import os
from huggingface_hub import list_repo_files
# Hub repository id of the base checkpoint; used below to load both the model and the tokenizer.
model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, SFTConfig
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
AutoConfig,
AutoModel,
MistralConfig
)
# ------------------------------------------------------------------
# CRITICAL FIX: Manually wire the Ministral3 Hierarchy
# ------------------------------------------------------------------
# The checkpoint pairs an outer "mistral3" multimodal config with an inner
# "ministral3" text config. Transformers releases without Ministral 3 support
# do not know the inner model type, and AutoModelForCausalLM may have no entry
# for Mistral3Config, so both are registered here by hand.
#
# Each registration is skipped when transformers already provides it: the
# auto-class ``register`` methods raise ValueError for names transformers
# already owns, and overriding a native implementation with this shim would
# be wrong anyway.
print("🔧 Starting Manual Registration/Wiring...")
try:
    from transformers import CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING

    # --- 1. Inner Text Model (Ministral) ---
    if "ministral3" in CONFIG_MAPPING:
        print(" ✅ Inner: 'ministral3' is natively supported; skipping shim")
    else:
        from transformers.models.ministral.configuration_ministral import MinistralConfig
        from transformers.models.ministral.modeling_ministral import MinistralModel

        # Compatibility Config for Inner Model: reuse Ministral's config under
        # the "ministral3" model type and backfill attention settings that the
        # loaded config may leave unset.
        class Ministral3CompatConfig(MinistralConfig):
            model_type = "ministral3"

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                if getattr(self, "sliding_window", None) is None:
                    self.sliding_window = 4096
                if getattr(self, "layer_types", None) is None:
                    # One entry per decoder layer; 40 is only a fallback count.
                    self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)

        # Compatibility Model for Inner Model
        class Ministral3CompatModel(MinistralModel):
            config_class = Ministral3CompatConfig

        # Register Inner Components
        AutoConfig.register("ministral3", Ministral3CompatConfig)
        AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
        print(" ✅ Registered Inner: 'ministral3' -> Ministral3CompatModel")

    # --- 2. Outer Multimodal Model (Mistral3) ---
    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

    # Register Outer Components with AutoModelForCausalLM (unless already mapped)
    if Mistral3Config in MODEL_FOR_CAUSAL_LM_MAPPING:
        print(" ✅ Outer: Mistral3Config already maps to a causal-LM class; skipping")
    else:
        AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
        print(" ✅ Registered Outer: Mistral3Config -> Mistral3ForConditionalGeneration")
except ImportError as e:
    # Best-effort: continue and let from_pretrained report any remaining problem.
    print(f" ❌ Failed to import/register classes: {e}")
    print(" ⚠️ This usually means the transformers version is too old or incompatible.")
# ------------------------------------------------------------------
# Standard Training Setup
# ------------------------------------------------------------------
# Number of examples kept for train + eval after filtering.
TARGET_NUM_EXAMPLES = 10_000
# Rows drawn before filtering. The surplus over TARGET_NUM_EXAMPLES leaves
# headroom for rows the filter drops, so the final count can still reach the
# target.
OVERSAMPLE_NUM_EXAMPLES = 12_000


def _has_text(value):
    """Return True if *value* is a string with at least one non-whitespace character."""
    return isinstance(value, str) and bool(value.strip())


# Load dataset
print("📦 Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")
print(f"✂️ Subsampling dataset to {TARGET_NUM_EXAMPLES:,} examples for efficiency...")
# Clamp so a split smaller than the oversample size does not make
# ``select`` raise IndexError.
dataset = dataset.shuffle(seed=42).select(range(min(OVERSAMPLE_NUM_EXAMPLES, len(dataset))))
print("🧹 Filtering invalid examples...")
# Keep only rows where both the question and the answer contain real text.
dataset = dataset.filter(lambda x: _has_text(x["Question"]) and _has_text(x["Answer"]))
if len(dataset) > TARGET_NUM_EXAMPLES:
    dataset = dataset.select(range(TARGET_NUM_EXAMPLES))
print("🔄 Mapping dataset...")
def to_messages(example):
    """Turn one AviationQA row into a single-turn chat transcript.

    Args:
        example: A dataset row providing ``"Question"`` and ``"Answer"`` fields.

    Returns:
        A dict with a single ``"messages"`` key: a user turn carrying the
        question, followed by an assistant turn carrying the answer.
    """
    user_turn = {"role": "user", "content": example["Question"]}
    assistant_turn = {"role": "assistant", "content": example["Answer"]}
    return {"messages": [user_turn, assistant_turn]}
# Replace the raw columns with the chat-formatted "messages" column only.
dataset = dataset.map(to_messages, remove_columns=dataset.column_names)
print("🔀 Creating train/eval split...")
# Hold out 10% for evaluation; the fixed seed keeps the split reproducible.
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
# 4-bit NF4 weights with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
print(f"🤖 Loading model {model_id}...")
# AutoModelForCausalLM relies on the Mistral3 wiring performed earlier in this
# script when transformers has no native causal-LM mapping for the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
# PEFT helper that readies a k-bit quantized model for adapter training.
model = prepare_model_for_kbit_training(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Keep the checkpoint's own pad token when it defines one and fall back to EOS
# only if it does not. Unconditionally aliasing pad to EOS lets collators that
# mask labels by pad-token id mask EOS as well, so the model never learns to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Fallback ChatML-style template, applied only when the tokenizer ships none.
# NOTE(review): "<|im_start|>"/"<|im_end|>" are presumably not special tokens
# in this model's vocabulary, so they would be learned as plain text. Confirm
# that the checkpoint provides its own chat template.
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
# LoRA adapter settings: rank 16, scaling alpha 32, no bias training.
# NOTE(review): these module names may also match projection layers outside
# the language model in a multimodal checkpoint; confirm that is intended.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        # attention projections (judging by name)
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections (judging by name)
        "gate_proj", "up_proj", "down_proj",
    ],
)
# SFT training arguments, grouped by concern. Per optimizer step each device
# accumulates per_device_train_batch_size * gradient_accumulation_steps = 16
# sequences.
config = SFTConfig(
    # Output directory and Hub syncing.
    output_dir="Ministral-3-14B-AviationQA-SFT",
    push_to_hub=True,
    hub_model_id="sunkencity/Ministral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",
    # Optimisation schedule.
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # Mixed precision: bfloat16 only.
    bf16=True,
    fp16=False,
    # Sequence handling.
    max_length=2048,
    dataset_kwargs={"add_special_tokens": False},
    # Checkpoint and evaluation cadence (both every 100 steps).
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    # Logging and experiment tracking.
    logging_steps=10,
    report_to="trackio",
    project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
)
# Trainer: SFTTrainer receives the LoRA settings via peft_config and the
# tokenizer as its processing class.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
    processing_class=tokenizer,
)
print("🚀 Starting training...")
trainer.train()
print("💾 Pushing to Hub...")
# Final push once training finishes, on top of the per-save pushes requested
# via hub_strategy in the SFTConfig.
trainer.push_to_hub()