sunkencity committed on
Commit
1ae6cb4
·
verified ·
1 Parent(s): 6f6fc96

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +21 -14
train_aviation.py CHANGED
@@ -42,29 +42,36 @@ from trl import SFTTrainer, SFTConfig
42
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
43
 
44
  # Register 'ministral3' config to handle nested text_config
45
- # ... (rest of registration logic)
46
  print("🔧 Registering ministral3 config (Monkey Patch Strategy)...")
47
  try:
48
  from transformers import MinistralConfig, AutoConfig
49
 
50
- # Monkey patch the model_type to match what the config.json has
51
- # This allows us to use the native class which is already registered with AutoModel
52
- print(f" Original MinistralConfig.model_type: {MinistralConfig.model_type}")
53
- MinistralConfig.model_type = "ministral3"
54
- print(f" Patched MinistralConfig.model_type: {MinistralConfig.model_type}")
55
 
56
- # Register the patched class for the "ministral3" key
57
- AutoConfig.register("ministral3", MinistralConfig)
58
- print(" Registered ministral3 -> MinistralConfig (native, patched)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  except Exception as e:
61
  print(f" ❌ Failed to patch/register ministral3 config: {e}")
62
 
63
  # Register Mistral3Config to a model class
64
- # ... (rest of registration kept as is)
65
- # ... (rest of registration kept as is)
66
- # ... (rest of registration kept as is)
67
- # ... (rest of registration kept as is)
68
  print("🔧 Registering Mistral3 model class...")
69
  try:
70
  from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
@@ -193,4 +200,4 @@ print("🚀 Starting training...")
193
  trainer.train()
194
 
195
  print("💾 Pushing to Hub...")
196
- trainer.push_to_hub()
 
42
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
43
 
44
# Register 'ministral3' config to handle nested text_config.
# The checkpoint's config.json declares model_type "ministral3" inside its
# text_config, but transformers has no config class registered under that
# key, so we register a thin MinistralConfig subclass that also fills in
# fields the Ministral modeling code expects.
print("🔧 Registering ministral3 config (Monkey Patch Strategy)...")
try:
    from transformers import MinistralConfig, AutoConfig

    # We need to ensure MinistralConfig has sliding_window and layer_types
    # if it's used as the inner text_config for Mistral3.
    # Create a temporary compatible class.
    class Ministral3CompatConfig(MinistralConfig):
        # Must match the `text_config["model_type"]` in the checkpoint's config.json.
        model_type = "ministral3"

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Ensure sliding_window is set, if null in config.json or missing.
            if getattr(self, "sliding_window", None) is None:
                self.sliding_window = 4096  # Default value for Mistral/Ministral models

            # Ensure layer_types is set, as it's expected by modeling_ministral.py.
            # FIX: the original `not hasattr(...)` check left `layer_types = None`
            # in place when the attribute existed but was null; test for None
            # explicitly, mirroring the sliding_window check above.
            if getattr(self, "layer_types", None) is None:
                # Assumes all layers are sliding attention if the model uses it.
                # Use getattr for num_hidden_layers as it might not be set yet
                # if the config is partial; default to 40 layers if not found.
                self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)

    # Register the compatible class for the "ministral3" key.
    AutoConfig.register("ministral3", Ministral3CompatConfig)
    print(" Registered ministral3 -> Ministral3CompatConfig (patched)")

except Exception as e:
    print(f" ❌ Failed to patch/register ministral3 config: {e}")
73
 
74
  # Register Mistral3Config to a model class
 
 
 
 
75
  print("🔧 Registering Mistral3 model class...")
76
  try:
77
  from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
 
200
# Run fine-tuning, then upload the trained adapter/model to the Hugging Face Hub.
# NOTE(review): `trainer` is an SFTTrainer constructed earlier in this script —
# both calls are side-effecting and produce no value worth capturing here.
trainer.train()

print("💾 Pushing to Hub...")
trainer.push_to_hub()