sunkencity committed on
Commit
7c82449
·
verified Β·
1 Parent(s): 6cb259e

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +19 -51
train_aviation.py CHANGED
@@ -3,7 +3,7 @@
3
  # "torch",
4
  # "trl>=0.12.0",
5
  # "peft>=0.7.0",
6
- # "transformers>=4.46.0",
7
  # "huggingface_hub>=0.26.0",
8
  # "accelerate>=0.24.0",
9
  # "trackio",
@@ -42,39 +42,7 @@ from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
42
  from trl import SFTTrainer, SFTConfig
43
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
44
 
45
- # Register 'ministral3' config to handle nested text_config (Removed, as patching directly)
46
- # This whole section is being replaced by direct config patching below.
47
- # print("🔧 Registering ministral3 config (Monkey Patch Strategy)...")
48
- # try:
49
- # from transformers import MinistralConfig, AutoConfig
50
- # class Ministral3CompatConfig(MinistralConfig):
51
- # model_type = "ministral3" # Ensure this matches the `text_config["model_type"]`
52
- # def __init__(self, **kwargs):
53
- # super().__init__(**kwargs)
54
- # if not hasattr(self, 'sliding_window') or self.sliding_window is None:
55
- # self.sliding_window = 4096
56
- # if not hasattr(self, 'layer_types'):
57
- # self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)
58
- # AutoConfig.register("ministral3", Ministral3CompatConfig)
59
- # print(" Registered ministral3 -> Ministral3CompatConfig (patched)")
60
- # except Exception as e:
61
- # print(f" ❌ Failed to patch/register ministral3 config: {e}")
62
-
63
- # Register Mistral3Config to a model class (Removed, not needed with direct patching)
64
- # print("🔧 Registering Mistral3 model class...")
65
- # try:
66
- # from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
67
- # try:
68
- # from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
69
- # AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
70
- # print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
71
- # except ImportError:
72
- # print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
73
- # from transformers import MistralForCausalLM
74
- # AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
75
- # print(" Registered Mistral3Config -> MistralForCausalLM")
76
- # except ImportError as e:
77
- # print(f" ❌ Failed to find Mistral3Config or register model: {e}")
78
 
79
 
80
  # Load dataset
@@ -120,27 +88,27 @@ bnb_config = BitsAndBytesConfig(
120
  bnb_4bit_use_double_quant=True,
121
  )
122
 
123
- # Load config first to patch
124
  print(f"🤖 Loading config for {model_id}...")
125
  config = AutoConfig.from_pretrained(model_id)
126
 
127
- # Patch text_config to include sliding_window and layer_types
128
- print("🔧 Patching config.text_config...")
129
- if hasattr(config, 'text_config'):
130
- if not hasattr(config.text_config, 'sliding_window') or config.text_config.sliding_window is None:
131
- config.text_config.sliding_window = 4096
132
- print(" Set config.text_config.sliding_window = 4096")
133
- if not hasattr(config.text_config, 'layer_types'):
134
- config.text_config.layer_types = ["sliding_attention"] * getattr(config.text_config, "num_hidden_layers", 40)
135
- print(" Set config.text_config.layer_types")
136
- else:
137
- print(" No text_config found, skipping patching.")
138
-
139
- # Load Model with the patched config
140
- print(f"🤖 Loading model {model_id} with patched config...")
141
  model = AutoModelForCausalLM.from_pretrained(
142
  model_id,
143
- config=config, # Pass the patched config
144
  quantization_config=bnb_config,
145
  device_map="auto",
146
  torch_dtype=torch.bfloat16,
@@ -203,4 +171,4 @@ print("🚀 Starting training...")
203
  trainer.train()
204
 
205
  print("💾 Pushing to Hub...")
206
- trainer.push_to_hub()
 
3
  # "torch",
4
  # "trl>=0.12.0",
5
  # "peft>=0.7.0",
6
+ # "transformers", # Let UV pick latest
7
  # "huggingface_hub>=0.26.0",
8
  # "accelerate>=0.24.0",
9
  # "trackio",
 
42
  from trl import SFTTrainer, SFTConfig
43
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
44
 
45
+ # All custom config registration logic removed, relying on latest transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  # Load dataset
 
88
  bnb_4bit_use_double_quant=True,
89
  )
90
 
91
+ # Load config first (AutoConfig should handle it now with latest transformers)
92
  print(f"🤖 Loading config for {model_id}...")
93
  config = AutoConfig.from_pretrained(model_id)
94
 
95
+ # Patch text_config to include sliding_window and layer_types (Now unnecessary, should be handled by latest transformers)
96
+ # print("🔧 Patching config.text_config...")
97
+ # if hasattr(config, 'text_config'):
98
+ # if not hasattr(config.text_config, 'sliding_window') or config.text_config.sliding_window is None:
99
+ # config.text_config.sliding_window = 4096
100
+ # print(" Set config.text_config.sliding_window = 4096")
101
+ # if not hasattr(config.text_config, 'layer_types'):
102
+ # config.text_config.layer_types = ["sliding_attention"] * getattr(config.text_config, "num_hidden_layers", 40)
103
+ # print(" Set config.text_config.layer_types")
104
+ # else:
105
+ # print(" No text_config found, skipping patching.")
106
+
107
+ # Load Model with the config
108
+ print(f"🤖 Loading model {model_id} with config...")
109
  model = AutoModelForCausalLM.from_pretrained(
110
  model_id,
111
+ config=config, # Pass the config
112
  quantization_config=bnb_config,
113
  device_map="auto",
114
  torch_dtype=torch.bfloat16,
 
171
  trainer.train()
172
 
173
  print("💾 Pushing to Hub...")
174
+ trainer.push_to_hub()