# /// script
# dependencies = [
#     "torch",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers",  # Let UV pick latest
#     "huggingface_hub>=0.26.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "bitsandbytes",
#     "scipy",
# ]
# ///
"""QLoRA SFT fine-tune of mistralai/Ministral-3-14B-Reasoning-2512 on AviationQA.

Pipeline: register compatibility classes for the Ministral3/Mistral3 model
hierarchy -> load and subsample the sakharamg/AviationQA dataset into chat
messages -> load the model 4-bit (NF4) -> train a LoRA adapter with TRL's
SFTTrainer -> push checkpoints and the final adapter to the Hugging Face Hub.
Intended to be run as a standalone `uv run` script (see PEP 723 header above).
"""
import trackio
import torch
import os
from huggingface_hub import list_repo_files

model_id = "mistralai/Ministral-3-14B-Reasoning-2512"

from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, SFTConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoConfig,
    AutoModel,
    MistralConfig,
)

# ------------------------------------------------------------------
# CRITICAL FIX: Manually wire the Ministral3 Hierarchy
#
# The checkpoint's config declares model types that this transformers
# install does not map to classes out of the box, so we register
# compatibility subclasses with the Auto* factories before loading.
# ------------------------------------------------------------------
print("🔧 Starting Manual Registration/Wiring...")

try:
    # --- 1. Inner Text Model (Ministral) ---
    from transformers.models.ministral.configuration_ministral import MinistralConfig
    from transformers.models.ministral.modeling_ministral import MinistralModel

    # Compatibility Config for Inner Model: fills in attributes newer
    # modeling code expects but older checkpoints may omit.
    class Ministral3CompatConfig(MinistralConfig):
        model_type = "ministral3"

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Default sliding-window size when absent/None in the checkpoint.
            if not hasattr(self, "sliding_window") or self.sliding_window is None:
                self.sliding_window = 4096
            # One attention-type tag per layer; 40 is the fallback layer count.
            if not hasattr(self, "layer_types"):
                self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)

    # Compatibility Model for Inner Model
    class Ministral3CompatModel(MinistralModel):
        config_class = Ministral3CompatConfig

    # Register Inner Components
    AutoConfig.register("ministral3", Ministral3CompatConfig)
    AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
    print("   ✅ Registered Inner: 'ministral3' -> Ministral3CompatModel")

    # --- 2. Outer Multimodal Model (Mistral3) ---
    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

    # Register Outer Components with AutoModelForCausalLM so
    # from_pretrained below resolves the checkpoint to this class.
    AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
    print("   ✅ Registered Outer: Mistral3Config -> Mistral3ForConditionalGeneration")

except ImportError as e:
    # Best-effort: loading may still work if transformers already knows
    # these model types natively.
    print(f"   ❌ Failed to import/register classes: {e}")
    print("   ⚠️ This usually means the transformers version is too old or incompatible.")

# ------------------------------------------------------------------
# Standard Training Setup
# ------------------------------------------------------------------

# Load dataset
print("📦 Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")

# Take 12k shuffled rows first so that, after filtering, ~10k remain.
# Guard with min() so a smaller-than-expected dataset doesn't raise.
print("✂️ Subsampling dataset to 10,000 examples for efficiency...")
dataset = dataset.shuffle(seed=42).select(range(min(12000, len(dataset))))

# Drop rows with missing or whitespace-only questions/answers.
print("🧹 Filtering invalid examples...")
dataset = dataset.filter(
    lambda x: x["Question"] and x["Answer"]
    and len(x["Question"].strip()) > 0
    and len(x["Answer"].strip()) > 0
)

# Cap at exactly 10k examples post-filter.
if len(dataset) > 10000:
    dataset = dataset.select(range(10000))

print("🔄 Mapping dataset...")


def to_messages(example):
    """Convert a raw Question/Answer row into TRL's chat `messages` format."""
    return {
        "messages": [
            {"role": "user", "content": example["Question"]},
            {"role": "assistant", "content": example["Answer"]},
        ]
    }


dataset = dataset.map(to_messages, remove_columns=dataset.column_names)

# 90/10 train/eval split, seeded for reproducibility.
print("🔀 Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

# 4-bit NF4 quantization with bf16 compute (QLoRA setup).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print(f"🤖 Loading model {model_id}...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
# Prepare quantized model for k-bit training (casts norms, enables grads).
model = prepare_model_for_kbit_training(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Fall back to a ChatML-style template if the checkpoint ships none.
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

# LoRA on all attention + MLP projections.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Effective batch size = 4 * 4 = 16; eval/save every 100 steps, with
# checkpoints pushed to the Hub on every save.
config = SFTConfig(
    output_dir="Ministral-3-14B-AviationQA-SFT",
    push_to_hub=True,
    hub_model_id="sunkencity/Ministral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    report_to="trackio",
    project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
    max_length=2048,
    dataset_kwargs={"add_special_tokens": False},
)

# Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
    processing_class=tokenizer,
)

print("🚀 Starting training...")
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()