sunkencity committed on
Commit
33b1a64
·
verified · 1 Parent(s): f2890b3
1 Parent(s): f2890b3

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +13 -8
train_aviation.py CHANGED
@@ -23,17 +23,23 @@ model_id = "mistralai/Ministral-3-14B-Reasoning-2512" # Defined at top level
23
  from datasets import load_dataset
24
  from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
25
  from trl import SFTTrainer, SFTConfig
26
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, MistralConfig
27
 
28
- # Explicitly register 'ministral3' model type to MistralConfig
29
- # This is a workaround for transformers not recognizing 'ministral3' internally
30
- class RegistrableMinistralConfig(MistralConfig):
31
  model_type = "ministral3"
32
 
33
  AutoConfig.register("ministral3", RegistrableMinistralConfig)
34
  print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
35
 
36
- # Register Mistral3Config to a model class
 
 
 
 
 
 
 
37
  print("🔧 Registering Mistral3 model class with AutoModelForCausalLM...")
38
  try:
39
  from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
@@ -42,8 +48,7 @@ try:
42
  print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
43
  except ImportError as e:
44
  print(f" โŒ Failed to import Mistral3 modeling classes: {e}")
45
- # Fallback if specific Mistral3 classes are not directly importable
46
- print(" Trying to register Mistral3Config to standard MistralForCausalLM as fallback.")
47
  from transformers import MistralForCausalLM
48
  try:
49
  AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
@@ -178,4 +183,4 @@ print("🚀 Starting training...")
178
  trainer.train()
179
 
180
  print("💾 Pushing to Hub...")
181
- trainer.push_to_hub()
 
23
  from datasets import load_dataset
24
  from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
25
  from trl import SFTTrainer, SFTConfig
26
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, MistralConfig, MinistralModel, AutoModel
27
 
28
+ # Explicitly register 'ministral3' model type to MistralConfig for the nested text config
29
+ class RegistrableMinistralConfig(MistralConfig): # Subclass from MistralConfig (base)
 
30
  model_type = "ministral3"
31
 
32
  AutoConfig.register("ministral3", RegistrableMinistralConfig)
33
  print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
34
 
35
+ # Register RegistrableMinistralConfig with AutoModel so Mistral3Model can load its language_model
36
+ try:
37
+ AutoModel.register(RegistrableMinistralConfig, MinistralModel)
38
+ print("🔧 Registered RegistrableMinistralConfig to MinistralModel for AutoModel.")
39
+ except Exception as e:
40
+ print(f" โŒ Failed to register RegistrableMinistralConfig with AutoModel: {e}")
41
+
42
+ # Register Mistral3Config to its model class for AutoModelForCausalLM
43
  print("🔧 Registering Mistral3 model class with AutoModelForCausalLM...")
44
  try:
45
  from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
 
48
  print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
49
  except ImportError as e:
50
  print(f" โŒ Failed to import Mistral3 modeling classes: {e}")
51
+ print(" Attempting fallback registration for Mistral3Config with standard MistralForCausalLM.")
 
52
  from transformers import MistralForCausalLM
53
  try:
54
  AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
 
183
  trainer.train()
184
 
185
  print("💾 Pushing to Hub...")
186
+ trainer.push_to_hub()