Upload train_aviation.py with huggingface_hub
Browse files- train_aviation.py +14 -10
train_aviation.py
CHANGED
|
@@ -39,17 +39,21 @@ from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
|
|
| 39 |
from trl import SFTTrainer, SFTConfig
|
| 40 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
|
| 41 |
|
| 42 |
-
# Register
|
| 43 |
-
print("🔧 Registering
|
| 44 |
try:
|
| 45 |
-
from transformers import
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Model ID
|
| 55 |
# model_id defined above
|
|
|
|
| 39 |
from trl import SFTTrainer, SFTConfig
|
| 40 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
|
| 41 |
|
| 42 |
+
# Register Mistral3Config to a model class
|
| 43 |
+
print("🔧 Registering Mistral3 model class...")
|
| 44 |
try:
|
| 45 |
+
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
|
| 46 |
+
try:
|
| 47 |
+
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
|
| 48 |
+
AutoModelForCausalLM.register(Mistral3Config, Mistral3ForConditionalGeneration)
|
| 49 |
+
print(" Registered Mistral3Config -> Mistral3ForConditionalGeneration")
|
| 50 |
+
except ImportError:
|
| 51 |
+
print(" Mistral3ForConditionalGeneration not found, trying MistralForCausalLM")
|
| 52 |
+
from transformers import MistralForCausalLM
|
| 53 |
+
AutoModelForCausalLM.register(Mistral3Config, MistralForCausalLM)
|
| 54 |
+
print(" Registered Mistral3Config -> MistralForCausalLM")
|
| 55 |
+
except ImportError as e:
|
| 56 |
+
print(f" ❌ Failed to find Mistral3Config or register model: {e}")
|
| 57 |
|
| 58 |
# Model ID
|
| 59 |
# model_id defined above
|