sunkencity committed on
Commit
20566c4
·
verified ·
1 Parent(s): 037cd7b

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +19 -22
train_aviation.py CHANGED
@@ -32,47 +32,46 @@ from transformers import (
32
  )
33
 
34
  # ------------------------------------------------------------------
35
- # CRITICAL FIX: Manually wire the Ministral3 Inner Model
36
  # ------------------------------------------------------------------
37
  print("🔧 Starting Manual Registration/Wiring...")
38
 
39
  try:
40
- # 1. Import the specific classes for Ministral (Inner Text Model)
41
  from transformers.models.ministral.configuration_ministral import MinistralConfig
42
  from transformers.models.ministral.modeling_ministral import MinistralModel
43
 
44
- print(" ✅ Found native MinistralConfig and MinistralModel")
45
-
46
- # 2. Create a Compatibility Config Class
47
  class Ministral3CompatConfig(MinistralConfig):
48
- model_type = "ministral3" # Match the JSON
49
-
50
  def __init__(self, **kwargs):
51
  super().__init__(**kwargs)
52
- # Inject missing attributes causing crashes
53
  if not hasattr(self, 'sliding_window') or self.sliding_window is None:
54
  self.sliding_window = 4096
55
  if not hasattr(self, 'layer_types'):
56
- # Default to sliding_attention for all layers if not specified
57
  self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)
58
 
59
- # 3. Create a Compatibility Model Class
60
- # This is required to satisfy the check: model.config_class == config_class
61
  class Ministral3CompatModel(MinistralModel):
62
  config_class = Ministral3CompatConfig
63
 
64
- # 4. Register Config with AutoConfig
65
  AutoConfig.register("ministral3", Ministral3CompatConfig)
66
- print(" ✅ Registered AutoConfig: 'ministral3' -> Ministral3CompatConfig")
67
-
68
- # 5. Register Model with AutoModel
69
- # Now this should pass because Ministral3CompatModel.config_class matches Ministral3CompatConfig
70
  AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
71
- print(" ✅ Registered AutoModel: Ministral3CompatConfig -> Ministral3CompatModel")
 
 
 
 
 
 
 
 
 
72
 
73
  except ImportError as e:
74
- print(f" ❌ Failed to import Ministral classes: {e}")
75
- print(" ⚠️ This usually means the transformers version is too old for Ministral-3.")
76
 
77
  # ------------------------------------------------------------------
78
  # Standard Training Setup
@@ -113,8 +112,6 @@ bnb_config = BitsAndBytesConfig(
113
  )
114
 
115
  print(f"🤖 Loading model {model_id}...")
116
- # We use AutoModelForCausalLM, which should now handle the outer Mistral3Config
117
- # and recursively handle the inner Ministral3CompatConfig via our registration above.
118
  model = AutoModelForCausalLM.from_pretrained(
119
  model_id,
120
  quantization_config=bnb_config,
@@ -175,4 +172,4 @@ print("🚀 Starting training...")
175
  trainer.train()
176
 
177
  print("💾 Pushing to Hub...")
178
- trainer.push_to_hub()
 
32
  )
33
 
34
  # ------------------------------------------------------------------
35
+ # CRITICAL FIX: Manually wire the Ministral3 Hierarchy
36
  # ------------------------------------------------------------------
37
  print("🔧 Starting Manual Registration/Wiring...")
38
 
39
  try:
40
+ # --- 1. Inner Text Model (Ministral) ---
41
  from transformers.models.ministral.configuration_ministral import MinistralConfig
42
  from transformers.models.ministral.modeling_ministral import MinistralModel
43
 
44
+ # Compatibility Config for Inner Model
 
 
45
  class Ministral3CompatConfig(MinistralConfig):
46
+ model_type = "ministral3"
 
47
  def __init__(self, **kwargs):
48
  super().__init__(**kwargs)
 
49
  if not hasattr(self, 'sliding_window') or self.sliding_window is None:
50
  self.sliding_window = 4096
51
  if not hasattr(self, 'layer_types'):
 
52
  self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)
53
 
54
+ # Compatibility Model for Inner Model
 
55
  class Ministral3CompatModel(MinistralModel):
56
  config_class = Ministral3CompatConfig
57
 
58
+ # Register Inner Components
59
  AutoConfig.register("ministral3", Ministral3CompatConfig)
 
 
 
 
60
  AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
61
+ print(" ✅ Registered Inner: 'ministral3' -> Ministral3CompatModel")
62
+
63
+
64
+ # --- 2. Outer Multimodal Model (Mistral3) ---
65
+ from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
66
+ from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
67
+
68
+ # Register Outer Components with AutoModelForCausalLM
69
+ AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
70
+ print(" ✅ Registered Outer: Mistral3Config -> Mistral3ForConditionalGeneration")
71
 
72
  except ImportError as e:
73
+ print(f" ❌ Failed to import/register classes: {e}")
74
+ print(" ⚠️ This usually means the transformers version is too old or incompatible.")
75
 
76
  # ------------------------------------------------------------------
77
  # Standard Training Setup
 
112
  )
113
 
114
  print(f"🤖 Loading model {model_id}...")
 
 
115
  model = AutoModelForCausalLM.from_pretrained(
116
  model_id,
117
  quantization_config=bnb_config,
 
172
  trainer.train()
173
 
174
  print("💾 Pushing to Hub...")
175
+ trainer.push_to_hub()