Upload train_aviation.py with huggingface_hub
Browse files- train_aviation.py +13 -8
train_aviation.py
CHANGED
|
@@ -23,17 +23,23 @@ model_id = "mistralai/Ministral-3-14B-Reasoning-2512" # Defined at top level
|
|
| 23 |
from datasets import load_dataset
|
| 24 |
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
|
| 25 |
from trl import SFTTrainer, SFTConfig
|
| 26 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, MistralConfig
|
| 27 |
|
| 28 |
-
# Explicitly register 'ministral3' model type to MistralConfig
|
| 29 |
-
|
| 30 |
-
class RegistrableMinistralConfig(MistralConfig):
|
| 31 |
model_type = "ministral3"
|
| 32 |
|
| 33 |
AutoConfig.register("ministral3", RegistrableMinistralConfig)
|
| 34 |
print("๐ง Registered 'ministral3' to RegistrableMinistralConfig.")
|
| 35 |
|
| 36 |
-
# Register
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
print("๐ง Registering Mistral3 model class with AutoModelForCausalLM...")
|
| 38 |
try:
|
| 39 |
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
|
@@ -42,8 +48,7 @@ try:
|
|
| 42 |
print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
|
| 43 |
except ImportError as e:
|
| 44 |
print(f" โ Failed to import Mistral3 modeling classes: {e}")
|
| 45 |
-
|
| 46 |
-
print(" Trying to register Mistral3Config to standard MistralForCausalLM as fallback.")
|
| 47 |
from transformers import MistralForCausalLM
|
| 48 |
try:
|
| 49 |
AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
|
|
@@ -178,4 +183,4 @@ print("๐ Starting training...")
|
|
| 178 |
trainer.train()
|
| 179 |
|
| 180 |
print("๐พ Pushing to Hub...")
|
| 181 |
-
trainer.push_to_hub()
|
|
|
|
| 23 |
from datasets import load_dataset
|
| 24 |
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
|
| 25 |
from trl import SFTTrainer, SFTConfig
|
| 26 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, MistralConfig, MinistralModel, AutoModel
|
| 27 |
|
| 28 |
+
# Explicitly register 'ministral3' model type to MistralConfig for the nested text config
|
| 29 |
+
class RegistrableMinistralConfig(MistralConfig): # Subclass from MistralConfig (base)
|
|
|
|
| 30 |
model_type = "ministral3"
|
| 31 |
|
| 32 |
AutoConfig.register("ministral3", RegistrableMinistralConfig)
|
| 33 |
print("๐ง Registered 'ministral3' to RegistrableMinistralConfig.")
|
| 34 |
|
| 35 |
+
# Register RegistrableMinistralConfig with AutoModel so Mistral3Model can load its language_model
|
| 36 |
+
try:
|
| 37 |
+
AutoModel.register(RegistrableMinistralConfig, MinistralModel)
|
| 38 |
+
print("๐ง Registered RegistrableMinistralConfig to MinistralModel for AutoModel.")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f" โ Failed to register RegistrableMinistralConfig with AutoModel: {e}")
|
| 41 |
+
|
| 42 |
+
# Register Mistral3Config to its model class for AutoModelForCausalLM
|
| 43 |
print("๐ง Registering Mistral3 model class with AutoModelForCausalLM...")
|
| 44 |
try:
|
| 45 |
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
|
|
|
| 48 |
print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
|
| 49 |
except ImportError as e:
|
| 50 |
print(f" โ Failed to import Mistral3 modeling classes: {e}")
|
| 51 |
+
print(" Attempting fallback registration for Mistral3Config with standard MistralForCausalLM.")
|
|
|
|
| 52 |
from transformers import MistralForCausalLM
|
| 53 |
try:
|
| 54 |
AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
|
|
|
|
| 183 |
trainer.train()
|
| 184 |
|
| 185 |
print("๐พ Pushing to Hub...")
|
| 186 |
+
trainer.push_to_hub()
|