sunkencity committed on
Commit
f2890b3
·
verified ·
1 Parent(s): d083c2d

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +19 -1
train_aviation.py CHANGED
@@ -33,6 +33,24 @@ class RegistrableMinistralConfig(MistralConfig):
33
  AutoConfig.register("ministral3", RegistrableMinistralConfig)
34
  print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # Load dataset
38
  print("📦 Loading dataset...")
@@ -160,4 +178,4 @@ print("🚀 Starting training...")
160
  trainer.train()
161
 
162
  print("💾 Pushing to Hub...")
163
- trainer.push_to_hub()
 
33
  AutoConfig.register("ministral3", RegistrableMinistralConfig)
34
  print("🔧 Registered 'ministral3' to RegistrableMinistralConfig.")
35
 
36
+ # Register Mistral3Config to a model class
37
+ print("🔧 Registering Mistral3 model class with AutoModelForCausalLM...")
38
+ try:
39
+ from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
40
+ from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
41
+ AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
42
+ print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
43
+ except ImportError as e:
44
+ print(f" ❌ Failed to import Mistral3 modeling classes: {e}")
45
+ # Fallback if specific Mistral3 classes are not directly importable
46
+ print(" Trying to register Mistral3Config to standard MistralForCausalLM as fallback.")
47
+ from transformers import MistralForCausalLM
48
+ try:
49
+ AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
50
+ print(" Registered Mistral3Config -> MistralForCausalLM (fallback)")
51
+ except Exception as fallback_e:
52
+ print(f" ❌ Fallback registration also failed: {fallback_e}")
53
+
54
 
55
  # Load dataset
56
  print("📦 Loading dataset...")
 
178
  trainer.train()
179
 
180
  print("💾 Pushing to Hub...")
181
+ trainer.push_to_hub()