Subi003 committed on
Commit
28eb6fb
·
verified ·
1 Parent(s): db0b1e7

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +22 -17
modeling.py CHANGED
@@ -11,27 +11,27 @@ from transformers import (
11
 
12
 
13
  # ============================================================
14
- # CONFIG CLASS (CRITICAL FIX HERE)
15
  # ============================================================
16
  class EmoAxisConfig(PretrainedConfig):
17
  model_type = "emoaxis"
18
 
19
- def __init__(self, num_labels=28, base_model_name="roberta-base", **kwargs):
20
- super().__init__(**kwargs)
 
21
  self.num_labels = num_labels
22
  self.base_model_name = base_model_name
23
 
24
- # 🚨 ULTIMATE FIX FOR THE TYPERROR 🚨
25
- # We overload this method to manually filter out arguments that cause the clash.
26
  @classmethod
27
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
28
- # Filter the arguments that AutoConfig passes internally, causing the 'multiple values' error
29
  kwargs.pop("return_unused_kwargs", None)
30
  kwargs.pop("config", None)
31
 
32
- # Call the base PretrainedConfig's method directly to avoid conflicts with super()
33
  return PretrainedConfig.from_pretrained(
34
- cls, # Pass the class itself (EmoAxisConfig) as the first argument
35
  pretrained_model_name_or_path,
36
  **kwargs
37
  )
@@ -46,7 +46,13 @@ class Encoder(nn.Module):
46
  self.encoder = base_encoder
47
 
48
  def forward(self, inputs):
49
- outputs = self.encoder(**inputs, output_hidden_states=True)
 
 
 
 
 
 
50
  last_hidden = outputs.hidden_states[-1]
51
 
52
  mask = inputs["attention_mask"].unsqueeze(-1).float()
@@ -74,17 +80,16 @@ class Classifier(nn.Module):
74
 
75
 
76
  # ============================================================
77
- # MAIN MODEL
78
  # ============================================================
79
  class EmoAxis(PreTrainedModel):
80
  config_class = EmoAxisConfig
81
 
82
- # FIX: Accept *args and **kwargs for resilience
83
- def __init__(self, config, *args, **kwargs):
84
- super().__init__(config, *args, **kwargs)
85
 
86
- # FIX: Load base encoder from its CONFIG to ensure fine-tuned weights are used
87
- # (prevents reinitialization of the base model weights)
88
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
89
 
90
  self.encoder = Encoder(base_encoder)
@@ -93,11 +98,11 @@ class EmoAxis(PreTrainedModel):
93
  num_classes=config.num_labels
94
  )
95
 
96
- def forward(self, input_ids=None, attention_mask=None, **kwargs):
 
97
  inputs = {
98
  "input_ids": input_ids,
99
  "attention_mask": attention_mask,
100
- **kwargs
101
  }
102
  pooled = self.encoder(inputs)
103
  logits = self.classifier(pooled)
 
11
 
12
 
13
# ============================================================
# CONFIG CLASS
# ============================================================
class EmoAxisConfig(PretrainedConfig):
    """Configuration for the EmoAxis multi-label emotion model.

    Stores the number of output labels and the Hugging Face checkpoint
    name of the base encoder that the model wraps.
    """

    model_type = "emoaxis"

    def __init__(self, num_labels=28, base_model_name="roberta-base", **kwargs):
        # **kwargs MUST be accepted: PretrainedConfig.from_dict()
        # re-instantiates the config with every key stored in config.json
        # (transformers_version, architectures, ...), so a fixed
        # two-argument signature raises TypeError on every reload.
        # Pop our own keys first so a value that also arrives via kwargs
        # cannot trigger a "got multiple values" clash.
        kwargs.pop("num_labels", None)
        kwargs.pop("base_model_name", None)
        # model_type is a class attribute; it is serialized automatically
        # and must not be forwarded to super().__init__().
        super().__init__(num_labels=num_labels, **kwargs)
        self.num_labels = num_labels
        self.base_model_name = base_model_name

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the config, dropping kwargs AutoConfig forwards internally.

        NOTE: calling ``PretrainedConfig.from_pretrained(cls, path, ...)``
        is wrong — ``from_pretrained`` is a classmethod, so accessing it
        through ``PretrainedConfig`` already binds ``cls``; an explicit
        ``cls`` argument would be received as the *path*. ``super()``
        performs the correct binding.
        """
        # Deliberately discard flags that AutoConfig may pass through and
        # that previously caused "multiple values" TypeErrors; callers of
        # this class never receive the unused-kwargs tuple.
        kwargs.pop("return_unused_kwargs", None)
        kwargs.pop("config", None)

        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
46
  self.encoder = base_encoder
47
 
48
  def forward(self, inputs):
49
+ # NOTE: If the base encoder takes specific non-standard arguments,
50
+ # this will break without **kwargs. We assume only input_ids/attention_mask are passed.
51
+ outputs = self.encoder(
52
+ input_ids=inputs["input_ids"],
53
+ attention_mask=inputs["attention_mask"],
54
+ output_hidden_states=True
55
+ )
56
  last_hidden = outputs.hidden_states[-1]
57
 
58
  mask = inputs["attention_mask"].unsqueeze(-1).float()
 
80
 
81
 
82
  # ============================================================
83
+ # MAIN MODEL (NO *args or **kwargs)
84
  # ============================================================
85
  class EmoAxis(PreTrainedModel):
86
  config_class = EmoAxisConfig
87
 
88
+ # Removed *args and **kwargs from signature
89
+ def __init__(self, config):
90
+ super().__init__(config)
91
 
92
+ # This line remains correct for loading saved weights
 
93
  base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
94
 
95
  self.encoder = Encoder(base_encoder)
 
98
  num_classes=config.num_labels
99
  )
100
 
101
+ # Removed **kwargs from signature
102
+ def forward(self, input_ids=None, attention_mask=None):
103
  inputs = {
104
  "input_ids": input_ids,
105
  "attention_mask": attention_mask,
 
106
  }
107
  pooled = self.encoder(inputs)
108
  logits = self.classifier(pooled)