Update modeling.py
Browse files- modeling.py +13 -7
modeling.py
CHANGED
|
@@ -11,7 +11,7 @@ from transformers import (
|
|
| 11 |
|
| 12 |
|
| 13 |
# ============================================================
|
| 14 |
-
# CONFIG CLASS
|
| 15 |
# ============================================================
|
| 16 |
class EmoAxisConfig(PretrainedConfig):
|
| 17 |
model_type = "emoaxis"
|
|
@@ -21,15 +21,20 @@ class EmoAxisConfig(PretrainedConfig):
|
|
| 21 |
self.num_labels = num_labels
|
| 22 |
self.base_model_name = base_model_name
|
| 23 |
|
| 24 |
-
# 🚨
|
|
|
|
| 25 |
@classmethod
|
| 26 |
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
| 27 |
-
# Filter arguments that
|
| 28 |
kwargs.pop("return_unused_kwargs", None)
|
| 29 |
kwargs.pop("config", None)
|
| 30 |
|
| 31 |
-
# Call the
|
| 32 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
# ============================================================
|
|
@@ -74,11 +79,12 @@ class Classifier(nn.Module):
|
|
| 74 |
class EmoAxis(PreTrainedModel):
|
| 75 |
config_class = EmoAxisConfig
|
| 76 |
|
| 77 |
-
# FIX: Accept *args and **kwargs
|
| 78 |
def __init__(self, config, *args, **kwargs):
|
| 79 |
super().__init__(config, *args, **kwargs)
|
| 80 |
|
| 81 |
-
# FIX: Load base encoder from
|
|
|
|
| 82 |
base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
|
| 83 |
|
| 84 |
self.encoder = Encoder(base_encoder)
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
# ============================================================
|
| 14 |
+
# CONFIG CLASS (CRITICAL FIX HERE)
|
| 15 |
# ============================================================
|
| 16 |
class EmoAxisConfig(PretrainedConfig):
|
| 17 |
model_type = "emoaxis"
|
|
|
|
| 21 |
self.num_labels = num_labels
|
| 22 |
self.base_model_name = base_model_name
|
| 23 |
|
| 24 |
+
# FIX for the TypeError ("got multiple values for argument ..."):
# AutoConfig forwards bookkeeping kwargs that collide with this config's
# constructor signature, so strip them before delegating to the parent loader.
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
    """Load an ``EmoAxisConfig``, dropping kwargs that AutoConfig injects.

    Args:
        pretrained_model_name_or_path: Hub model id or local directory
            containing the serialized config.
        **kwargs: Extra overrides. ``return_unused_kwargs`` and ``config``
            are removed because the Auto* machinery passes them internally
            and they clash with the parent loader ("multiple values" error).

    Returns:
        The loaded configuration instance (an ``EmoAxisConfig``).
    """
    kwargs.pop("return_unused_kwargs", None)
    kwargs.pop("config", None)
    # BUG FIX: the previous code called
    #     PretrainedConfig.from_pretrained(cls, path, **kwargs)
    # but from_pretrained is a *classmethod* already bound to
    # PretrainedConfig, so the explicit `cls` was received as
    # `pretrained_model_name_or_path` (a class object, not a path) and the
    # real path shifted into the wrong parameter. Zero-argument super()
    # inside a classmethod binds cls correctly with no conflict.
    return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 38 |
|
| 39 |
|
| 40 |
# ============================================================
|
|
|
|
| 79 |
class EmoAxis(PreTrainedModel):
|
| 80 |
config_class = EmoAxisConfig
|
| 81 |
|
| 82 |
+
# FIX: Accept *args and **kwargs for resilience
|
| 83 |
def __init__(self, config, *args, **kwargs):
|
| 84 |
super().__init__(config, *args, **kwargs)
|
| 85 |
|
| 86 |
+
# FIX: Build the base encoder from its config only (random init here); the fine-tuned weights are restored afterwards from the checkpoint's state_dict
|
| 87 |
+
# (avoids downloading/loading the base pretrained weights just to overwrite them)
|
| 88 |
base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
|
| 89 |
|
| 90 |
self.encoder = Encoder(base_encoder)
|