|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import trackio |
|
|
import torch |
|
|
import os |
|
|
from huggingface_hub import list_repo_files |
|
|
|
|
|
|
|
|
# --- Hub-access diagnostics: verify the HF token works before committing to
# the expensive model/dataset downloads below. ---
print("π DIAGNOSTICS:")
token = os.environ.get("HF_TOKEN")
print(f" HF_TOKEN env var present: {bool(token)}")
if token:
    # Log only a short prefix so the secret is never fully printed.
    print(f" HF_TOKEN prefix: {token[:4]}...")

# NOTE: model_id is reused later by the model/tokenizer loading code.
model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
try:
    print(f" Attempting to list files for {model_id}...")
    files = list_repo_files(model_id, token=token)
    # Bug fix: this f-string was split across two physical source lines, which
    # is a SyntaxError for a single-quoted string; rejoined onto one line.
    print(f" β Success! Found {len(files)} files.")
    print(f" First 5 files: {files[:5]}")
except Exception as e:
    # Best-effort diagnostic: report and keep going so the traceback from the
    # real model load (if any) is still reached.
    print(f" β Failed to list repo files: {e}")
print("="*40)
|
|
|
|
|
from datasets import load_dataset |
|
|
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model |
|
|
from trl import SFTTrainer, SFTConfig |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig |
|
|
|
|
|
|
|
|
# Work around the checkpoint's config.json declaring model_type="ministral3",
# a name this transformers install may not recognize: reuse the native
# MinistralConfig class, monkey-patch its model_type string, and register it
# with AutoConfig under the new name.
print("π§ Registering ministral3 config (Monkey Patch Strategy)...")
try:
    from transformers import MinistralConfig, AutoConfig

    print(f" Original MinistralConfig.model_type: {MinistralConfig.model_type}")
    # NOTE(review): patching before AutoConfig.register presumably matters
    # because register validates the class's model_type against the key —
    # confirm against the installed transformers version.
    MinistralConfig.model_type = "ministral3"
    print(f" Patched MinistralConfig.model_type: {MinistralConfig.model_type}")

    AutoConfig.register("ministral3", MinistralConfig)
    print(" Registered ministral3 -> MinistralConfig (native, patched)")
except Exception as e:
    # Best-effort: log and continue; the load below may still succeed if the
    # installed transformers already understands "ministral3" natively.
    print(f" β Failed to patch/register ministral3 config: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Map Mistral3Config to a concrete model class so that
# AutoModelForCausalLM.from_pretrained can instantiate this architecture.
print("π§ Registering Mistral3 model class...")
try:
    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    try:
        # Preferred mapping when this transformers build ships the dedicated
        # Mistral3 class.
        from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
        AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
        print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
    except ImportError:
        # Fallback: the plain Mistral causal-LM class.
        print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
        from transformers import MistralForCausalLM
        AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
        print(" Registered Mistral3Config -> MistralForCausalLM")
except ImportError as e:
    # Only ImportError is caught here (unlike the broad catch above) — a
    # registration failure of another kind would still propagate.
    print(f" β Failed to find Mistral3Config or register model: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pull the AviationQA question/answer pairs and trim the corpus to a
# manageable size for a single-epoch fine-tune.
print("π¦ Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")

print("βοΈ Subsampling dataset to 10,000 examples for efficiency...")
# Shuffle with a fixed seed for reproducibility, then keep the first 10k rows.
shuffled = dataset.shuffle(seed=42)
dataset = shuffled.select(range(10_000))

print("π Mapping dataset...")
|
def to_messages(example):
    """Convert one raw AviationQA row into a chat-format record.

    Reads the row's "Question" and "Answer" fields and returns a dict with a
    single "messages" key holding a two-turn user/assistant conversation.
    """
    user_turn = {"role": "user", "content": example["Question"]}
    assistant_turn = {"role": "assistant", "content": example["Answer"]}
    return {"messages": [user_turn, assistant_turn]}
|
|
# Replace the raw columns with the single chat-format "messages" column, then
# carve out a 10% held-out eval split (seeded for reproducibility).
original_columns = dataset.column_names
dataset = dataset.map(to_messages, remove_columns=original_columns)

print("π Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = dataset_split["train"], dataset_split["test"]
|
|
|
|
|
|
|
|
# 4-bit NF4 quantization settings (QLoRA-style): weights stored in 4-bit,
# compute performed in bf16, with nested (double) quantization enabled.
_quantization_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
bnb_config = BitsAndBytesConfig(**_quantization_kwargs)
|
|
|
|
|
|
|
|
print(f"π€ Loading model {model_id}...")
# Load the base model quantized in 4-bit with automatic device placement.
# NOTE(review): recent transformers releases deprecate `torch_dtype` in
# favour of `dtype` — confirm which one the installed version expects.
_load_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
model = AutoModelForCausalLM.from_pretrained(model_id, **_load_kwargs)
# Standard PEFT preparation step before training on top of a k-bit model.
model = prepare_model_for_kbit_training(model)
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token (common when a tokenizer defines no pad token).
tokenizer.pad_token = tokenizer.eos_token

# Minimal ChatML-style template used only when the checkpoint ships none.
_FALLBACK_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
if tokenizer.chat_template is None:
    tokenizer.chat_template = _FALLBACK_CHAT_TEMPLATE
|
|
|
|
|
|
|
|
# LoRA adapters over every attention and MLP projection matrix.
_LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,               # adapter rank
    lora_alpha=32,      # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    bias="none",        # leave base-model biases untouched
    target_modules=_LORA_TARGET_MODULES,
)
|
|
|
|
|
|
|
|
# Trainer hyperparameters, grouped by concern for readability.
_hub_settings = dict(
    push_to_hub=True,
    hub_model_id="sunkencity/Mistral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",
)
_optimization_settings = dict(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size of 16 per device
    learning_rate=2e-4,
    fp16=False,
    bf16=True,  # matches the bf16 compute dtype used for quantization above
)
_eval_and_logging_settings = dict(
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    report_to="trackio",
    project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
)
config = SFTConfig(
    output_dir="Mistral-3-14B-AviationQA-SFT",
    max_length=2048,
    dataset_kwargs={"add_special_tokens": False},
    **_hub_settings,
    **_optimization_settings,
    **_eval_and_logging_settings,
)
|
|
|
|
|
|
|
|
# Wire everything together: quantized base model + LoRA config + chat data.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
    # NOTE(review): recent trl versions renamed this parameter to
    # `processing_class`; confirm the installed trl still accepts `tokenizer`.
    tokenizer=tokenizer,
)
|
|
|
|
|
# Run the fine-tune; intermediate checkpoints are also pushed on every save
# per hub_strategy="every_save" in the config above... see SFTConfig at top.
print("π Starting training...")
trainer.train()

# Final upload of the trained model/adapter to the Hub repository.
print("πΎ Pushing to Hub...")
trainer.push_to_hub()
|
|
|