"""QLoRA supervised fine-tuning of a Ministral-3 14B checkpoint on AviationQA."""

import os

import torch
import trackio
from datasets import load_dataset
from huggingface_hub import list_repo_files
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    MistralConfig,
)
from trl import SFTConfig, SFTTrainer

# Hub repository of the base checkpoint; used for both the model and the
# tokenizer further down in this script.
model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
print(" 🔧 Starting Manual Registration/Wiring...")

# Wire the "ministral3" text backbone and the Mistral3 wrapper into the Auto*
# factories so that AutoModelForCausalLM.from_pretrained can resolve them.
try:
    from transformers.models.ministral.configuration_ministral import MinistralConfig
    from transformers.models.ministral.modeling_ministral import MinistralModel

    class Ministral3CompatConfig(MinistralConfig):
        """MinistralConfig exposed under the ``ministral3`` model type.

        After the parent constructor runs, ``sliding_window`` and
        ``layer_types`` are filled with defaults if they are still missing.
        """

        model_type = "ministral3"

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Default the window size when it is absent or explicitly None.
            if getattr(self, "sliding_window", None) is None:
                self.sliding_window = 4096
            # One entry per hidden layer; 40 is the fallback layer count.
            if not hasattr(self, "layer_types"):
                self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)

    class Ministral3CompatModel(MinistralModel):
        """MinistralModel bound to :class:`Ministral3CompatConfig`."""

        config_class = Ministral3CompatConfig

    # NOTE(review): register() is expected to raise ValueError when the name or
    # config is already registered; that is not handled here -- confirm against
    # the installed transformers version.
    AutoConfig.register("ministral3", Ministral3CompatConfig)
    AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
    print(" ✅ Registered Inner: 'ministral3' -> Ministral3CompatModel")

    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

    AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
    print(" ✅ Registered Outer: Mistral3Config -> Mistral3ForConditionalGeneration")

except ImportError as e:
    # Best effort: report the problem and let the script continue.
    print(f" ❌ Failed to import/register classes: {e}")
    print(" ⚠️ This usually means the transformers version is too old or incompatible.")
print("📦 Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")

print("✂️ Subsampling dataset to 10,000 examples for efficiency...")
# Draw 12,000 shuffled rows so that up to 10,000 are left after the filter
# below; clamp to the split size so a smaller split does not raise.
dataset = dataset.shuffle(seed=42).select(range(min(12000, len(dataset))))


def _is_valid_example(example):
    """Return True when both "Question" and "Answer" hold non-blank text."""
    question = example["Question"]
    answer = example["Answer"]
    return bool(question and answer and question.strip() and answer.strip())


print("🧹 Filtering invalid examples...")
dataset = dataset.filter(_is_valid_example)
if len(dataset) > 10000:
    dataset = dataset.select(range(10000))

print("🔄 Mapping dataset...")
def to_messages(example: dict) -> dict:
    """Convert one Question/Answer row into a two-turn chat record.

    Args:
        example: Mapping with "Question" and "Answer" entries.

    Returns:
        A dict with a single "messages" key holding a user turn (the
        question) followed by an assistant turn (the answer).

    Raises:
        KeyError: If "Question" or "Answer" is missing from ``example``.
    """
    return {
        "messages": [
            {"role": "user", "content": example["Question"]},
            {"role": "assistant", "content": example["Answer"]},
        ]
    }
# Replace every original column with the single "messages" column.
dataset = dataset.map(to_messages, remove_columns=dataset.column_names)

print("🔀 Creating train/eval split...")
# Fixed seed keeps the 90/10 split reproducible between runs.
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
# 4-bit NF4 quantization with double quantization and bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print(f"🤖 Loading model {model_id}...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
# Prepare the quantized model for k-bit training before it goes to the trainer.
model = prepare_model_for_kbit_training(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Use EOS for padding only as a fallback: overwriting an existing pad token
# would make padding indistinguishable from the end-of-sequence marker.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Install a simple <|im_start|>/<|im_end|> template only when the tokenizer
# ships without one.
if tokenizer.chat_template is None:
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )
# LoRA adapter settings: rank 16, alpha 32, applied to the attention and MLP
# projection modules listed in target_modules.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
# Training arguments: one epoch in bf16, with logging every 10 steps and
# evaluation, checkpointing and Hub upload every 100 steps.
config = SFTConfig(
    # Output location and Hub publishing
    output_dir="Ministral-3-14B-AviationQA-SFT",
    push_to_hub=True,
    hub_model_id="sunkencity/Ministral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",
    # Optimization schedule
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,
    # Logging, checkpointing and evaluation cadence
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    # Experiment tracking
    report_to="trackio",
    project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
    # Sequence handling
    max_length=2048,
    dataset_kwargs={"add_special_tokens": False},
)
# The trainer receives the prepared base model together with the LoRA config.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
    processing_class=tokenizer,
)

print("🚀 Starting training...")
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()