Subi003 commited on
Commit
0f8e09d
·
verified ·
1 Parent(s): 0bdac00

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +13 -1
modeling.py CHANGED
@@ -21,6 +21,16 @@ class EmoAxisConfig(PretrainedConfig):
21
  self.num_labels = num_labels
22
  self.base_model_name = base_model_name
23
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # ============================================================
26
  # ENCODER MODULE
@@ -64,9 +74,11 @@ class Classifier(nn.Module):
64
  class EmoAxis(PreTrainedModel):
65
  config_class = EmoAxisConfig
66
 
 
67
  def __init__(self, config, *args, **kwargs):
68
  super().__init__(config, *args, **kwargs)
69
 
 
70
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
71
 
72
  self.encoder = Encoder(base_encoder)
@@ -91,4 +103,4 @@ class EmoAxis(PreTrainedModel):
91
  # REGISTER WITH HUGGINGFACE AUTOCLASSES
92
  # ============================================================
93
  AutoConfig.register("emoaxis", EmoAxisConfig)
94
- AutoModel.register(EmoAxisConfig, EmoAxis)
 
21
  self.num_labels = num_labels
22
  self.base_model_name = base_model_name
23
 
24
+ # 🚨 CRITICAL FIX: Overload from_pretrained to prevent argument clashes.
25
+ @classmethod
26
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
27
+ # Filter arguments that cause TypeErrors in nested calls
28
+ kwargs.pop("return_unused_kwargs", None)
29
+ kwargs.pop("config", None)
30
+
31
+ # Call the parent method with cleaned arguments
32
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
33
+
34
 
35
  # ============================================================
36
  # ENCODER MODULE
 
74
  class EmoAxis(PreTrainedModel):
75
  config_class = EmoAxisConfig
76
 
77
+ # FIX: Accept *args and **kwargs to avoid initialization errors.
78
  def __init__(self, config, *args, **kwargs):
79
  super().__init__(config, *args, **kwargs)
80
 
81
+ # FIX: Load base encoder from config to ensure fine-tuned weights are used.
82
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
83
 
84
  self.encoder = Encoder(base_encoder)
 
103
  # REGISTER WITH HUGGINGFACE AUTOCLASSES
104
  # ============================================================
105
  AutoConfig.register("emoaxis", EmoAxisConfig)
106
+ AutoModel.register(EmoAxisConfig, EmoAxis)