Upload train_aviation.py with huggingface_hub
Browse files- train_aviation.py +21 -14
train_aviation.py
CHANGED
|
@@ -42,29 +42,36 @@ from trl import SFTTrainer, SFTConfig
|
|
| 42 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
|
| 43 |
|
| 44 |
# Register 'ministral3' config to handle nested text_config
|
| 45 |
-
# ... (rest of registration logic)
|
| 46 |
print("🔧 Registering ministral3 config (Monkey Patch Strategy)...")
|
| 47 |
try:
|
| 48 |
from transformers import MinistralConfig, AutoConfig
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
#
|
| 52 |
-
|
| 53 |
-
MinistralConfig.model_type = "ministral3"
|
| 54 |
-
print(f" Patched MinistralConfig.model_type: {MinistralConfig.model_type}")
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
except Exception as e:
|
| 61 |
print(f" ❌ Failed to patch/register ministral3 config: {e}")
|
| 62 |
|
| 63 |
# Register Mistral3Config to a model class
|
| 64 |
-
# ... (rest of registration kept as is)
|
| 65 |
-
# ... (rest of registration kept as is)
|
| 66 |
-
# ... (rest of registration kept as is)
|
| 67 |
-
# ... (rest of registration kept as is)
|
| 68 |
print("🔧 Registering Mistral3 model class...")
|
| 69 |
try:
|
| 70 |
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
|
@@ -193,4 +200,4 @@ print("🚀 Starting training...")
|
|
| 193 |
trainer.train()
|
| 194 |
|
| 195 |
print("💾 Pushing to Hub...")
|
| 196 |
-
trainer.push_to_hub()
|
|
|
|
| 42 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
|
| 43 |
|
| 44 |
# Register 'ministral3' config to handle nested text_config
|
|
|
|
| 45 |
print("🔧 Registering ministral3 config (Monkey Patch Strategy)...")
|
| 46 |
try:
|
| 47 |
from transformers import MinistralConfig, AutoConfig
|
| 48 |
|
| 49 |
+
# We need to ensure MinistralConfig has sliding_window and layer_types if it's used
|
| 50 |
+
# as the inner text_config for Mistral3.
|
| 51 |
+
# Define a compatibility subclass rather than mutating MinistralConfig itself.
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
class Ministral3CompatConfig(MinistralConfig):
    """MinistralConfig subclass exposed under the ``"ministral3"`` model type.

    After the parent constructor has consumed ``kwargs``, two attributes are
    back-filled when they are absent or ``None``:

    * ``sliding_window`` falls back to 4096.
    * ``layer_types`` falls back to one ``"sliding_attention"`` entry per
      hidden layer (40 entries when ``num_hidden_layers`` is not set).

    Values that are already populated are left untouched.

    NOTE(review): the original author's comment states that
    ``modeling_ministral.py`` expects ``layer_types`` to be present; that
    requirement is not verifiable from this file -- confirm against the
    installed transformers version.
    """

    model_type = "ministral3"  # Ensure this matches the `text_config["model_type"]`

    def __init__(self, **kwargs):
        """Build the config from ``kwargs`` and fill in missing defaults.

        Args:
            **kwargs: Forwarded unchanged to ``MinistralConfig.__init__``.
        """
        super().__init__(**kwargs)

        # Covers both "attribute missing" and "attribute explicitly null in
        # config.json".
        if getattr(self, "sliding_window", None) is None:
            self.sliding_window = 4096  # Default value for Mistral/Ministral models

        # Same missing-or-None treatment as sliding_window; a bare hasattr()
        # check would let an existing ``layer_types = None`` slip through.
        if getattr(self, "layer_types", None) is None:
            # num_hidden_layers might not be set if the config is partial;
            # fall back to 40 layers in that case.
            layer_count = getattr(self, "num_hidden_layers", 40)
            # Assumes all layers are sliding attention if the model uses it.
            self.layer_types = ["sliding_attention"] * layer_count
|
| 66 |
+
|
| 67 |
+
# Register the compatible class for the "ministral3" key
|
| 68 |
+
AutoConfig.register("ministral3", Ministral3CompatConfig)
|
| 69 |
+
print(" Registered ministral3 -> Ministral3CompatConfig (patched)")
|
| 70 |
|
| 71 |
except Exception as e:
|
| 72 |
print(f" ❌ Failed to patch/register ministral3 config: {e}")
|
| 73 |
|
| 74 |
# Register Mistral3Config to a model class
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
print("🔧 Registering Mistral3 model class...")
|
| 76 |
try:
|
| 77 |
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
|
|
|
| 200 |
trainer.train()
|
| 201 |
|
| 202 |
print("💾 Pushing to Hub...")
|
| 203 |
+
trainer.push_to_hub()
|