File size: 6,159 Bytes
3e19754
 
9b26fb6
3e19754
 
9b26fb6
 
3e19754
 
 
 
 
 
 
 
 
18a3f9a
 
 
 
31bab64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a3f9a
3e19754
 
 
5c34124
 
8259512
fdda1e0
8259512
fdda1e0
304ee13
fdda1e0
 
 
 
 
 
 
 
 
304ee13
8259512
fdda1e0
8259512
63a87b1
d815ac3
304ee13
105152c
fdda1e0
63a87b1
5c34124
63a87b1
 
 
 
 
 
 
 
 
 
 
 
3e19754
 
18a3f9a
 
3e19754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe50797
3e19754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfbfcaf
3e19754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# /// script
# dependencies = [
#     "torch", 
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.46.0",
#     "huggingface_hub>=0.26.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "bitsandbytes",
#     "scipy",
# ]
# ///

import trackio
import torch
import os
from huggingface_hub import list_repo_files

# DEBUG: Check token and repo access
# Prints whether HF_TOKEN is set and whether the (possibly gated) model repo is
# reachable, so auth problems surface before the slow model download below.
print("πŸ” DIAGNOSTICS:")
token = os.environ.get("HF_TOKEN")
print(f"   HF_TOKEN env var present: {bool(token)}")
if token:
    # Only the first 4 characters are printed to avoid leaking the secret in logs.
    print(f"   HF_TOKEN prefix: {token[:4]}...")

# NOTE: `model_id` defined here is reused later for model and tokenizer loading.
model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
try:
    print(f"   Attempting to list files for {model_id}...")
    files = list_repo_files(model_id, token=token)
    print(f"   βœ… Success! Found {len(files)} files.")
    print(f"   First 5 files: {files[:5]}")
except Exception as e:
    # Broad catch is deliberate: this is a best-effort diagnostic and must not
    # abort the run (network errors, 401/403 gated-repo responses, etc.).
    print(f"   ❌ Failed to list repo files: {e}")
print("="*40)

from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig

# Register 'ministral3' config to handle nested text_config
print("πŸ”§ Registering ministral3 config (Monkey Patch Strategy)...")
try:
    from transformers import MinistralConfig, AutoConfig

    # Monkey patch the model_type to match what the config.json has
    # This allows us to use the native class which is already registered with AutoModel
    # NOTE(review): this mutates the class attribute globally, so every
    # MinistralConfig created afterwards reports model_type "ministral3".
    print(f"   Original MinistralConfig.model_type: {MinistralConfig.model_type}")
    MinistralConfig.model_type = "ministral3"
    print(f"   Patched MinistralConfig.model_type: {MinistralConfig.model_type}")

    # Register the patched class for the "ministral3" key
    # (the patch above presumably exists because AutoConfig.register checks
    # that the class's model_type matches the key — confirm against the
    # installed transformers version)
    AutoConfig.register("ministral3", MinistralConfig)
    print("   Registered ministral3 -> MinistralConfig (native, patched)")

except Exception as e:
    # MinistralConfig may not exist in the installed transformers build, and
    # register() raises if "ministral3" is already mapped; either way the script
    # continues and lets the later from_pretrained call surface the real error.
    print(f"   ❌ Failed to patch/register ministral3 config: {e}")

# Register Mistral3Config with AutoModelForCausalLM so from_pretrained can
# resolve the model class for this config type.
print("πŸ”§ Registering Mistral3 model class...")
try:
    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
    try:
        from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
        AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
        print("   Registered Mistral3Config -> Mistral3ForConditionalGeneration")
    except ImportError:
        # Older transformers builds ship the config but not the
        # conditional-generation class; fall back to the plain causal-LM head.
        print("   Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
        from transformers import MistralForCausalLM
        AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
        print("   Registered Mistral3Config -> MistralForCausalLM")
except Exception as e:
    # Broadened from ImportError: AutoModelForCausalLM.register raises
    # ValueError when a conflicting mapping already exists, and that failure
    # should be reported here (matching the ministral3 block above) rather
    # than crash the script before the informative model-load error.
    print(f"   ❌ Failed to find Mistral3Config or register model: {e}")

# Model ID
# model_id defined above


# Load the raw AviationQA question/answer corpus.
print("πŸ“¦ Loading dataset...")
dataset = load_dataset("sakharamg/AviationQA", split="train")

# The full corpus (~1M rows) is far larger than needed for one LoRA pass, so
# train on a fixed-seed random subset to keep the job tractable.
print("βœ‚οΈ Subsampling dataset to 10,000 examples for efficiency...")
sample_size = 10_000
dataset = dataset.shuffle(seed=42).select(range(sample_size))

# Map to chat format
print("πŸ”„ Mapping dataset...")
def to_messages(example):
    """Convert one raw AviationQA row into the chat `messages` schema:
    a user turn from the Question column and an assistant turn from Answer."""
    user_turn = {"role": "user", "content": example["Question"]}
    assistant_turn = {"role": "assistant", "content": example["Answer"]}
    return {"messages": [user_turn, assistant_turn]}
# Replace the raw Question/Answer columns with the chat-formatted `messages` column.
dataset = dataset.map(to_messages, remove_columns=dataset.column_names)

# Hold out 10% of the subsample for periodic evaluation during training.
print("πŸ”€ Creating train/eval split...")
split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]

# Quantization Config (4-bit for memory efficiency)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16 even though stored weights are 4-bit
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants to save more memory
)

# Load Model
print(f"πŸ€– Loading model {model_id}...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # 4-bit NF4 config defined above
    device_map="auto",  # let accelerate place/shard layers across available devices
    torch_dtype=torch.bfloat16,  # NOTE(review): newer transformers prefer `dtype=`; `torch_dtype` is the deprecated spelling — confirm against pinned version
    attn_implementation="eager" # Default attention for compatibility
)
# PEFT helper that readies a quantized model for training (e.g. casting
# norm layers to fp32 and enabling input gradients for k-bit/QLoRA training).
model = prepare_model_for_kbit_training(model)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token; many causal-LM tokenizers ship without one.
tokenizer.pad_token = tokenizer.eos_token
# Fix for some models that miss chat_template or padding
# Fallback template emits ChatML-style <|im_start|>/<|im_end|> turn markers.
if tokenizer.chat_template is None:
    tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

# LoRA Config
peft_config = LoraConfig(
    r=16,  # adapter rank
    lora_alpha=32,  # scaling factor (alpha/r = 2)
    lora_dropout=0.05,
    bias="none",  # leave base-model biases frozen and untouched
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # all attention + MLP projections
)

# Training Config
config = SFTConfig(
    output_dir="Mistral-3-14B-AviationQA-SFT",
    push_to_hub=True,
    hub_model_id="sunkencity/Mistral-3-14B-AviationQA-SFT",
    hub_strategy="every_save",  # upload every checkpoint save to the Hub repo
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size 16 per device
    learning_rate=2e-4,
    fp16=False,
    bf16=True,  # matches bnb_4bit_compute_dtype above
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",  # evaluation shares the 100-step cadence with saving
    eval_steps=100,
    report_to="trackio",
    project="aviation-qa-tuning",  # NOTE(review): trackio grouping field — confirm the pinned transformers/trl versions forward it
    run_name="mistral-14b-sft-v1",
    max_length=2048,  # truncation length for tokenized sequences
    dataset_kwargs={"add_special_tokens": False} # Let tokenizer handle chat template
)

# Trainer
# NOTE: `processing_class` replaces the `tokenizer` argument, which
# transformers 4.46 deprecated (and newer trl/transformers releases that the
# `>=` dependency pins allow have removed entirely). The pinned floors
# (transformers>=4.46.0, trl>=0.12.0) guarantee the new name exists.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
    processing_class=tokenizer,
)

print("πŸš€ Starting training...")
trainer.train()

# Upload the final adapter/checkpoint to the Hub repo configured in SFTConfig.
print("πŸ’Ύ Pushing to Hub...")
trainer.push_to_hub()