sunkencity committed on
Commit
6cb259e
·
verified ·
1 Parent(s): 1ae6cb4

Upload train_aviation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_aviation.py +56 -53
train_aviation.py CHANGED
@@ -18,13 +18,15 @@ import os
18
  from huggingface_hub import list_repo_files
19
 
20
  # DEBUG: Check token and repo access
 
21
  # print("πŸ” DIAGNOSTICS:")
22
  # token = os.environ.get("HF_TOKEN")
23
  # print(f" HF_TOKEN env var present: {bool(token)}")
24
  # if token:
25
  # print(f" HF_TOKEN prefix: {token[:4]}...")
26
 
27
- # model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
 
28
  # try:
29
  # print(f" Attempting to list files for {model_id}...")
30
  # files = list_repo_files(model_id, token=token)
@@ -34,61 +36,45 @@ from huggingface_hub import list_repo_files
34
  # print(f" ❌ Failed to list repo files: {e}")
35
  # print("="*40)
36
 
37
- model_id = "mistralai/Ministral-3-14B-Reasoning-2512"
38
 
39
  from datasets import load_dataset
40
  from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
41
  from trl import SFTTrainer, SFTConfig
42
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
43
 
44
- # Register 'ministral3' config to handle nested text_config
45
- print("πŸ”§ Registering ministral3 config (Monkey Patch Strategy)...")
46
- try:
47
- from transformers import MinistralConfig, AutoConfig
48
-
49
- # We need to ensure MinistralConfig has sliding_window and layer_types if it's used
50
- # as the inner text_config for Mistral3.
51
- # Create a temporary compatible class.
52
-
53
- class Ministral3CompatConfig(MinistralConfig):
54
- model_type = "ministral3" # Ensure this matches the `text_config["model_type"]`
55
- def __init__(self, **kwargs):
56
- super().__init__(**kwargs)
57
- # Ensure sliding_window is set, if null in config.json or missing
58
- if not hasattr(self, 'sliding_window') or self.sliding_window is None:
59
- self.sliding_window = 4096 # Default value for Mistral/Ministral models
60
-
61
- # Ensure layer_types is set, as it's expected by modeling_ministral.py
62
- if not hasattr(self, 'layer_types'):
63
- # Assumes all layers are sliding attention if the model uses it
64
- # Use getattr for num_hidden_layers as it might not be set yet if config is partial
65
- self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40) # Default to 40 if not found
66
-
67
- # Register the compatible class for the "ministral3" key
68
- AutoConfig.register("ministral3", Ministral3CompatConfig)
69
- print(" Registered ministral3 -> Ministral3CompatConfig (patched)")
70
-
71
- except Exception as e:
72
- print(f" ❌ Failed to patch/register ministral3 config: {e}")
73
-
74
- # Register Mistral3Config to a model class
75
- print("πŸ”§ Registering Mistral3 model class...")
76
- try:
77
- from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
78
- try:
79
- from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
80
- AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
81
- print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
82
- except ImportError:
83
- print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
84
- from transformers import MistralForCausalLM
85
- AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
86
- print(" Registered Mistral3Config -> MistralForCausalLM")
87
- except ImportError as e:
88
- print(f" ❌ Failed to find Mistral3Config or register model: {e}")
89
-
90
- # Model ID
91
- # model_id defined above
92
 
93
 
94
  # Load dataset
@@ -134,10 +120,27 @@ bnb_config = BitsAndBytesConfig(
134
  bnb_4bit_use_double_quant=True,
135
  )
136
 
137
- # Load Model
138
- print(f"πŸ€– Loading model {model_id}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  model = AutoModelForCausalLM.from_pretrained(
140
  model_id,
 
141
  quantization_config=bnb_config,
142
  device_map="auto",
143
  torch_dtype=torch.bfloat16,
@@ -200,4 +203,4 @@ print("πŸš€ Starting training...")
200
  trainer.train()
201
 
202
  print("πŸ’Ύ Pushing to Hub...")
203
- trainer.push_to_hub()
 
18
  from huggingface_hub import list_repo_files
19
 
20
  # DEBUG: Check token and repo access
21
+ # (commented out for cleaner logs now that it works)
22
  # print("πŸ” DIAGNOSTICS:")
23
  # token = os.environ.get("HF_TOKEN")
24
  # print(f" HF_TOKEN env var present: {bool(token)}")
25
  # if token:
26
  # print(f" HF_TOKEN prefix: {token[:4]}...")
27
 
28
+ model_id = "mistralai/Ministral-3-14B-Reasoning-2512" # Defined at top level
29
+
30
  # try:
31
  # print(f" Attempting to list files for {model_id}...")
32
  # files = list_repo_files(model_id, token=token)
 
36
  # print(f" ❌ Failed to list repo files: {e}")
37
  # print("="*40)
38
 
 
39
 
40
  from datasets import load_dataset
41
  from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
42
  from trl import SFTTrainer, SFTConfig
43
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
44
 
45
+ # Register 'ministral3' config to handle nested text_config (Removed, as patching directly)
46
+ # This whole section is being replaced by direct config patching below.
47
+ # print("πŸ”§ Registering ministral3 config (Monkey Patch Strategy)...")
48
+ # try:
49
+ # from transformers import MinistralConfig, AutoConfig
50
+ # class Ministral3CompatConfig(MinistralConfig):
51
+ # model_type = "ministral3" # Ensure this matches the `text_config["model_type"]`
52
+ # def __init__(self, **kwargs):
53
+ # super().__init__(**kwargs)
54
+ # if not hasattr(self, 'sliding_window') or self.sliding_window is None:
55
+ # self.sliding_window = 4096
56
+ # if not hasattr(self, 'layer_types'):
57
+ # self.layer_types = ["sliding_attention"] * getattr(self, "num_hidden_layers", 40)
58
+ # AutoConfig.register("ministral3", Ministral3CompatConfig)
59
+ # print(" Registered ministral3 -> Ministral3CompatConfig (patched)")
60
+ # except Exception as e:
61
+ # print(f" ❌ Failed to patch/register ministral3 config: {e}")
62
+
63
+ # Register Mistral3Config to a model class (Removed, not needed with direct patching)
64
+ # print("πŸ”§ Registering Mistral3 model class...")
65
+ # try:
66
+ # from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
67
+ # try:
68
+ # from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
69
+ # AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
70
+ # print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
71
+ # except ImportError:
72
+ # print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
73
+ # from transformers import MistralForCausalLM
74
+ # AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
75
+ # print(" Registered Mistral3Config -> MistralForCausalLM")
76
+ # except ImportError as e:
77
+ # print(f" ❌ Failed to find Mistral3Config or register model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # Load dataset
 
120
  bnb_4bit_use_double_quant=True,
121
  )
122
 
123
+ # Load config first to patch
124
+ print(f"πŸ€– Loading config for {model_id}...")
125
+ config = AutoConfig.from_pretrained(model_id)
126
+
127
+ # Patch text_config to include sliding_window and layer_types
128
+ print("πŸ”§ Patching config.text_config...")
129
+ if hasattr(config, 'text_config'):
130
+ if not hasattr(config.text_config, 'sliding_window') or config.text_config.sliding_window is None:
131
+ config.text_config.sliding_window = 4096
132
+ print(" Set config.text_config.sliding_window = 4096")
133
+ if not hasattr(config.text_config, 'layer_types'):
134
+ config.text_config.layer_types = ["sliding_attention"] * getattr(config.text_config, "num_hidden_layers", 40)
135
+ print(" Set config.text_config.layer_types")
136
+ else:
137
+ print(" No text_config found, skipping patching.")
138
+
139
+ # Load Model with the patched config
140
+ print(f"πŸ€– Loading model {model_id} with patched config...")
141
  model = AutoModelForCausalLM.from_pretrained(
142
  model_id,
143
+ config=config, # Pass the patched config
144
  quantization_config=bnb_config,
145
  device_map="auto",
146
  torch_dtype=torch.bfloat16,
 
203
  trainer.train()
204
 
205
  print("πŸ’Ύ Pushing to Hub...")
206
+ trainer.push_to_hub()