Subi003 commited on
Commit
db0b1e7
·
verified ·
1 Parent(s): 0f8e09d

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +13 -7
modeling.py CHANGED
@@ -11,7 +11,7 @@ from transformers import (
11
 
12
 
13
  # ============================================================
14
- # CONFIG CLASS
15
  # ============================================================
16
  class EmoAxisConfig(PretrainedConfig):
17
  model_type = "emoaxis"
@@ -21,15 +21,20 @@ class EmoAxisConfig(PretrainedConfig):
21
  self.num_labels = num_labels
22
  self.base_model_name = base_model_name
23
 
24
- # 🚨 CRITICAL FIX: Overload from_pretrained to prevent argument clashes.
 
25
  @classmethod
26
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
27
- # Filter arguments that cause TypeErrors in nested calls
28
  kwargs.pop("return_unused_kwargs", None)
29
  kwargs.pop("config", None)
30
 
31
- # Call the parent method with cleaned arguments
32
- return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
 
 
 
33
 
34
 
35
  # ============================================================
@@ -74,11 +79,12 @@ class Classifier(nn.Module):
74
  class EmoAxis(PreTrainedModel):
75
  config_class = EmoAxisConfig
76
 
77
- # FIX: Accept *args and **kwargs to avoid initialization errors.
78
  def __init__(self, config, *args, **kwargs):
79
  super().__init__(config, *args, **kwargs)
80
 
81
- # FIX: Load base encoder from config to ensure fine-tuned weights are used.
 
82
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
83
 
84
  self.encoder = Encoder(base_encoder)
 
11
 
12
 
13
  # ============================================================
14
+ # CONFIG CLASS (CRITICAL FIX HERE)
15
  # ============================================================
16
  class EmoAxisConfig(PretrainedConfig):
17
  model_type = "emoaxis"
 
21
  self.num_labels = num_labels
22
  self.base_model_name = base_model_name
23
 
24
+ # 🚨 ULTIMATE FIX FOR THE TYPERROR 🚨
25
+ # We overload this method to manually filter out arguments that cause the clash.
26
  @classmethod
27
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
28
+ # Filter the arguments that AutoConfig passes internally, causing the 'multiple values' error
29
  kwargs.pop("return_unused_kwargs", None)
30
  kwargs.pop("config", None)
31
 
32
+ # Call the base PretrainedConfig's method directly to avoid conflicts with super()
33
+ return PretrainedConfig.from_pretrained(
34
+ cls, # Pass the class itself (EmoAxisConfig) as the first argument
35
+ pretrained_model_name_or_path,
36
+ **kwargs
37
+ )
38
 
39
 
40
  # ============================================================
 
79
  class EmoAxis(PreTrainedModel):
80
  config_class = EmoAxisConfig
81
 
82
+ # FIX: Accept *args and **kwargs for resilience
83
  def __init__(self, config, *args, **kwargs):
84
  super().__init__(config, *args, **kwargs)
85
 
86
+ # FIX: Load base encoder from its CONFIG to ensure fine-tuned weights are used
87
+ # (prevents reinitialization of the base model weights)
88
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
89
 
90
  self.encoder = Encoder(base_encoder)