sunkencity committed
Commit 3b8ec8c Β· verified Β· 1 Parent(s): 33b1a64

Upload train_aviation.py with huggingface_hub

Files changed (1)
  1. train_aviation.py +56 -68
train_aviation.py CHANGED
@@ -17,65 +17,76 @@ import torch
 import os
 from huggingface_hub import list_repo_files
 
-model_id = "mistralai/Ministral-3-14B-Reasoning-2512" # Defined at top level
-
+model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
 
 from datasets import load_dataset
 from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
 from trl import SFTTrainer, SFTConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, MistralConfig, MinistralModel, AutoModel
-
-# Explicitly register 'ministral3' model type to MistralConfig for the nested text config
-class RegistrableMinistralConfig(MistralConfig): # Subclass from MistralConfig (base)
-    model_type = "ministral3"
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    AutoConfig,
+    AutoModel,
+    MistralConfig  # Standard Mistral
+)
 
-AutoConfig.register("ministral3", RegistrableMinistralConfig)
-print("πŸ”§ Registered 'ministral3' to RegistrableMinistralConfig.")
+# ------------------------------------------------------------------
+# CRITICAL FIX: Manually wire the Ministral3 Inner Model
+# ------------------------------------------------------------------
+print("πŸ”§ Starting Manual Registration/Wiring...")
 
-# Register RegistrableMinistralConfig with AutoModel so Mistral3Model can load its language_model
 try:
-    AutoModel.register(RegistrableMinistralConfig, MinistralModel)
-    print("πŸ”§ Registered RegistrableMinistralConfig to MinistralModel for AutoModel.")
-except Exception as e:
-    print(f" ❌ Failed to register RegistrableMinistralConfig with AutoModel: {e}")
+    # 1. Import the specific classes for Ministral (Inner Text Model)
+    #    The traceback confirmed these exist in the installed transformers version
+    from transformers.models.ministral.configuration_ministral import MinistralConfig
+    from transformers.models.ministral.modeling_ministral import MinistralModel
+
+    print(" βœ… Found native MinistralConfig and MinistralModel")
+
+    # 2. Create a Compatibility Config Class
+    #    The hub config says "model_type": "ministral3", but code expects attributes not in the JSON.
+    class Ministral3CompatConfig(MinistralConfig):
+        model_type = "ministral3"  # Match the JSON
+
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            # Inject missing attributes causing crashes
+            if not hasattr(self, 'sliding_window') or self.sliding_window is None:
+                self.sliding_window = 4096
+            if not hasattr(self, 'layer_types'):
+                # Default to sliding_attention for all layers if not specified
+                self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)
+
+    # 3. Register Config with AutoConfig (so it handles "model_type": "ministral3")
+    AutoConfig.register("ministral3", Ministral3CompatConfig)
+    print(" βœ… Registered AutoConfig: 'ministral3' -> Ministral3CompatConfig")
+
+    # 4. Register Model with AutoModel (so AutoModel.from_config knows what to build)
+    #    THIS WAS THE MISSING PIECE causing "Unrecognized configuration class"
+    AutoModel.register(Ministral3CompatConfig, MinistralModel)
+    print(" βœ… Registered AutoModel: Ministral3CompatConfig -> MinistralModel")
 
-# Register Mistral3Config to its model class for AutoModelForCausalLM
-print("πŸ”§ Registering Mistral3 model class with AutoModelForCausalLM...")
-try:
-    from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
-    from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
-    AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
-    print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
 except ImportError as e:
-    print(f" ❌ Failed to import Mistral3 modeling classes: {e}")
-    print(" Attempting fallback registration for Mistral3Config with standard MistralForCausalLM.")
-    from transformers import MistralForCausalLM
-    try:
-        AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
-        print(" Registered Mistral3Config -> MistralForCausalLM (fallback)")
-    except Exception as fallback_e:
-        print(f" ❌ Fallback registration also failed: {fallback_e}")
+    print(f" ❌ Failed to import Ministral classes: {e}")
+    print(" ⚠️ This usually means the transformers version is too old for Ministral-3.")
 
+# ------------------------------------------------------------------
+# Standard Training Setup
+# ------------------------------------------------------------------
 
 # Load dataset
 print("πŸ“¦ Loading dataset...")
 dataset = load_dataset("sakharamg/AviationQA", split="train")
 
-# Limit dataset size for reasonable training time (e.g., 10k examples)
-# 1M rows is too large for a single generic fine-tuning job without massive compute.
 print("βœ‚οΈ Subsampling dataset to 10,000 examples for efficiency...")
-dataset = dataset.shuffle(seed=42).select(range(12000)) # Take slightly more to account for filtering
+dataset = dataset.shuffle(seed=42).select(range(12000))
 
-# Filter out empty/null examples to prevent chat template errors
 print("🧹 Filtering invalid examples...")
 dataset = dataset.filter(lambda x: x["Question"] and x["Answer"] and len(x["Question"].strip()) > 0 and len(x["Answer"].strip()) > 0)
-print(f" Remaining examples after filtering: {len(dataset)}")
-
-# Limit to final count
 if len(dataset) > 10000:
     dataset = dataset.select(range(10000))
 
-# Map to chat format
 print("πŸ”„ Mapping dataset...")
 def to_messages(example):
     return {
@@ -86,13 +97,11 @@ def to_messages(example):
     }
 dataset = dataset.map(to_messages, remove_columns=dataset.column_names)
 
-# Split
 print("πŸ”€ Creating train/eval split...")
 dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
 train_dataset = dataset_split["train"]
 eval_dataset = dataset_split["test"]
 
-# Quantization Config (4-bit for memory efficiency)
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
@@ -100,42 +109,23 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_use_double_quant=True,
 )
 
-# Load config first
-print(f"πŸ€– Loading config for {model_id}...")
-config = AutoConfig.from_pretrained(model_id)
-
-# Patch text_config to include sliding_window and layer_types
-print("πŸ”§ Patching config.text_config...")
-if hasattr(config, 'text_config'):
-    if not hasattr(config.text_config, 'sliding_window') or config.text_config.sliding_window is None:
-        config.text_config.sliding_window = 4096
-        print(" Set config.text_config.sliding_window = 4096")
-    if not hasattr(config.text_config, 'layer_types'):
-        config.text_config.layer_types = ["sliding_attention"] * getattr(config.text_config, "num_hidden_layers", 40)
-        print(" Set config.text_config.layer_types")
-else:
-    print(" No text_config found, skipping patching.")
-
-# Load Model with the config
-print(f"πŸ€– Loading model {model_id} with config...")
+print(f"πŸ€– Loading model {model_id}...")
+# We use AutoModelForCausalLM, which should now handle the outer Mistral3Config
+# and recursively handle the inner Ministral3CompatConfig via our registration above.
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    config=config, # Pass the config
     quantization_config=bnb_config,
     device_map="auto",
     torch_dtype=torch.bfloat16,
-    attn_implementation="eager" # Default attention for compatibility
+    attn_implementation="eager"
 )
 model = prepare_model_for_kbit_training(model)
 
-# Tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 tokenizer.pad_token = tokenizer.eos_token
-# Fix for some models that miss chat_template or padding
 if tokenizer.chat_template is None:
     tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 
-# LoRA Config
 peft_config = LoraConfig(
     r=16,
     lora_alpha=32,
@@ -145,11 +135,10 @@ peft_config = LoraConfig(
     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
 )
 
-# Training Config
 config = SFTConfig(
-    output_dir="Mistral-3-14B-AviationQA-SFT",
+    output_dir="Ministral-3-14B-AviationQA-SFT",
     push_to_hub=True,
-    hub_model_id="sunkencity/Mistral-3-14B-AviationQA-SFT",
+    hub_model_id="sunkencity/Ministral-3-14B-AviationQA-SFT",
     hub_strategy="every_save",
     num_train_epochs=1,
     per_device_train_batch_size=4,
@@ -166,10 +155,9 @@ config = SFTConfig(
     project="aviation-qa-tuning",
    run_name="mistral-14b-sft-v1",
     max_length=2048,
-    dataset_kwargs={"add_special_tokens": False} # Let tokenizer handle chat template
+    dataset_kwargs={"add_special_tokens": False}
 )
 
-# Trainer
 trainer = SFTTrainer(
     model=model,
     train_dataset=train_dataset,
@@ -183,4 +171,4 @@ print("πŸš€ Starting training...")
 trainer.train()
 
 print("πŸ’Ύ Pushing to Hub...")
-trainer.push_to_hub()
+trainer.push_to_hub()
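
Not part of the commit: a minimal sanity check one could run right after the registration block above, in the same Python session, since it reuses the Ministral3CompatConfig class defined there. AutoConfig.for_model and AutoModel.from_config are standard transformers APIs; the tiny layer and hidden sizes are illustrative assumptions, not the checkpoint's real dimensions.

# Hypothetical sanity check for the "ministral3" wiring -- assumes the
# registration block from train_aviation.py has already run in this session,
# so Ministral3CompatConfig is defined and registered.
from transformers import AutoConfig, AutoModel

# The registered model_type should now resolve to the compat config class.
cfg = AutoConfig.for_model("ministral3")
assert isinstance(cfg, Ministral3CompatConfig), type(cfg)
print("sliding_window:", cfg.sliding_window)
print("layer_types present:", getattr(cfg, "layer_types", None) is not None)

# AutoModel.from_config should now build a MinistralModel for that config.
# The tiny sizes below are made up purely to keep the random-weight model small.
tiny_cfg = Ministral3CompatConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
tiny_model = AutoModel.from_config(tiny_cfg)
print(type(tiny_model).__name__)  # expected: MinistralModel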
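Also not from the commit: the diff elides the body of to_messages, so the message structure below is an assumption about what it returns. Rendering the ChatML-style fallback template directly with Jinja2 shows roughly what a formatted training example would look like without downloading the tokenizer; the question/answer pair is invented for illustration, not an actual AviationQA row.

# Illustrative preview of the ChatML-style fallback chat template from the script.
# The messages list below is an assumed shape for the to_messages() output.
from jinja2 import Template

chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

messages = [
    {"role": "user", "content": "Which engines power the Boeing 747-8?"},
    {"role": "assistant", "content": "The Boeing 747-8 is powered by four GEnx-2B67 turbofans."},
]

print(Template(chat_template).render(messages=messages, add_generation_prompt=False))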
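Finally, a hedged sketch of consuming the run's output. It assumes trainer.push_to_hub() publishes a PEFT adapter to the hub_model_id above, that the ministral3 registration shim from train_aviation.py is executed first in a fresh process, that the base tokenizer ships a chat template (the script falls back to a ChatML one when it does not), and that the 4-bit compute dtype matches the line not shown in these hunks.

# Hypothetical inference with the pushed adapter -- run the registration shim
# from train_aviation.py first, otherwise loading the base model may fail with
# the same "ministral3" config-resolution error this commit works around.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_id = "mistralai/Ministral-3-14B-Reasoning-2512"
adapter_id = "sunkencity/Ministral-3-14B-AviationQA-SFT"  # hub_model_id from the SFTConfig

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed; not visible in this diff
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(
    base_id,
    quantization_config=bnb,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
model = PeftModel.from_pretrained(base, adapter_id)
model.eval()

messages = [{"role": "user", "content": "What does the ICAO code EDDF refer to?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))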