sunkencity committed
Commit 037cd7b · verified · 1 Parent(s): 3b8ec8c

Upload train_aviation.py with huggingface_hub

Files changed (1)
  1. train_aviation.py +13 -9
train_aviation.py CHANGED
@@ -28,7 +28,7 @@ from transformers import (
     BitsAndBytesConfig,
     AutoConfig,
     AutoModel,
-    MistralConfig  # Standard Mistral
+    MistralConfig
 )

 # ------------------------------------------------------------------
@@ -38,14 +38,12 @@ print("🔧 Starting Manual Registration/Wiring...")

 try:
     # 1. Import the specific classes for Ministral (Inner Text Model)
-    # The traceback confirmed these exist in the installed transformers version
     from transformers.models.ministral.configuration_ministral import MinistralConfig
     from transformers.models.ministral.modeling_ministral import MinistralModel

     print("   ✅ Found native MinistralConfig and MinistralModel")

     # 2. Create a Compatibility Config Class
-    # The hub config says "model_type": "ministral3", but code expects attributes not in the JSON.
     class Ministral3CompatConfig(MinistralConfig):
         model_type = "ministral3"  # Match the JSON

@@ -58,14 +56,19 @@ try:
             # Default to sliding_attention for all layers if not specified
             self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)

-    # 3. Register Config with AutoConfig (So it handles "model_type": "ministral3")
+    # 3. Create a Compatibility Model Class
+    # This is required to satisfy the check: model.config_class == config_class
+    class Ministral3CompatModel(MinistralModel):
+        config_class = Ministral3CompatConfig
+
+    # 4. Register Config with AutoConfig
     AutoConfig.register("ministral3", Ministral3CompatConfig)
     print("   ✅ Registered AutoConfig: 'ministral3' -> Ministral3CompatConfig")

-    # 4. Register Model with AutoModel (So AutoModel.from_config knows what to build)
-    # THIS WAS THE MISSING PIECE causing "Unrecognized configuration class"
-    AutoModel.register(Ministral3CompatConfig, MinistralModel)
-    print("   ✅ Registered AutoModel: Ministral3CompatConfig -> MinistralModel")
+    # 5. Register Model with AutoModel
+    # Now this should pass because Ministral3CompatModel.config_class matches Ministral3CompatConfig
+    AutoModel.register(Ministral3CompatConfig, Ministral3CompatModel)
+    print("   ✅ Registered AutoModel: Ministral3CompatConfig -> Ministral3CompatModel")

 except ImportError as e:
     print(f"   ❌ Failed to import Ministral classes: {e}")
@@ -158,6 +161,7 @@ config = SFTConfig(
     dataset_kwargs={"add_special_tokens": False}
 )

+# Trainer
 trainer = SFTTrainer(
     model=model,
     train_dataset=train_dataset,
@@ -171,4 +175,4 @@ print("🚀 Starting training...")
 trainer.train()

 print("💾 Pushing to Hub...")
-trainer.push_to_hub()
+trainer.push_to_hub()
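
For context on why this commit adds a compat model class: transformers' AutoModel.register(config_class, model_class) raises a ValueError when the model class's config_class attribute does not match the config class being registered, which is the check the new Ministral3CompatModel subclass satisfies. A minimal, self-contained sketch of the same registration pattern, using hypothetical ToyConfig/ToyModel classes (illustrative names, not from train_aviation.py):

import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"  # analogous to "model_type": "ministral3" in the hub config.json

    def __init__(self, hidden_size=8, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

class ToyModel(PreTrainedModel):
    # Without this attribute pointing at ToyConfig, AutoModel.register()
    # raises the "config_class ... not consistent" ValueError that
    # Ministral3CompatModel works around in the commit above.
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.linear(x)

# Register the pair, mirroring steps 4 and 5 of the script.
AutoConfig.register("toy", ToyConfig)
AutoModel.register(ToyConfig, ToyModel)

# AutoModel.from_config now resolves ToyConfig -> ToyModel, which is the
# same mechanism the script relies on for "ministral3".
model = AutoModel.from_config(ToyConfig())
print(type(model).__name__)  # ToyModel

Because AutoConfig.register keys on model_type, a config loaded from JSON with "model_type": "toy" would resolve through the same path, just as the script's registered classes handle "ministral3".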