Upload train_aviation.py with huggingface_hub
Browse files- train_aviation.py +19 -1
train_aviation.py
CHANGED
|
@@ -33,6 +33,24 @@ class RegistrableMinistralConfig(MistralConfig):
|
|
| 33 |
AutoConfig.register("ministral3", RegistrableMinistralConfig)
|
| 34 |
print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# Load dataset
|
| 38 |
print("📦 Loading dataset...")
|
|
@@ -160,4 +178,4 @@ print("🚀 Starting training...")
|
|
| 160 |
trainer.train()
|
| 161 |
|
| 162 |
print("💾 Pushing to Hub...")
|
| 163 |
-
trainer.push_to_hub()
|
|
|
|
| 33 |
AutoConfig.register("ministral3", RegistrableMinistralConfig)
|
| 34 |
print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
|
| 35 |
|
| 36 |
+
# Register Mistral3Config to a model class
|
| 37 |
+
print("🔧 Registering Mistral3 model class with AutoModelForCausalLM...")
|
| 38 |
+
try:
|
| 39 |
+
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
| 40 |
+
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
|
| 41 |
+
AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
|
| 42 |
+
print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
|
| 43 |
+
except ImportError as e:
|
| 44 |
+
print(f" ❌ Failed to import Mistral3 modeling classes: {e}")
|
| 45 |
+
# Fallback if specific Mistral3 classes are not directly importable
|
| 46 |
+
print(" Trying to register Mistral3Config to standard MistralForCausalLM as fallback.")
|
| 47 |
+
from transformers import MistralForCausalLM
|
| 48 |
+
try:
|
| 49 |
+
AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
|
| 50 |
+
print(" Registered Mistral3Config -> MistralForCausalLM (fallback)")
|
| 51 |
+
except Exception as fallback_e:
|
| 52 |
+
print(f" ❌ Fallback registration also failed: {fallback_e}")
|
| 53 |
+
|
| 54 |
|
| 55 |
# Load dataset
|
| 56 |
print("📦 Loading dataset...")
|
|
|
|
| 178 |
trainer.train()
|
| 179 |
|
| 180 |
print("💾 Pushing to Hub...")
|
| 181 |
+
trainer.push_to_hub()
|