Subi003 committed on
Commit
0a9ac4d
·
verified ·
1 Parent(s): 28eb6fb

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +9 -69
modeling.py CHANGED
@@ -1,69 +1,20 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
- from transformers import (
6
- PreTrainedModel,
7
- PretrainedConfig,
8
- AutoModel,
9
- AutoConfig,
10
- )
11
-
12
-
13
- # ============================================================
14
- # CONFIG CLASS (NO **kwargs in __init__)
15
- # ============================================================
16
- class EmoAxisConfig(PretrainedConfig):
17
- model_type = "emoaxis"
18
-
19
- def __init__(self, num_labels=28, base_model_name="roberta-base"):
20
- # Explicitly pass model_type to super() if removing **kwargs
21
- super().__init__(model_type=self.model_type, num_labels=num_labels, base_model_name=base_model_name)
22
- self.num_labels = num_labels
23
- self.base_model_name = base_model_name
24
-
25
- # CRITICAL FIX: Overload from_pretrained (must retain **kwargs here to work)
26
- # NOTE: You MUST retain **kwargs in the from_pretrained signature for it to function correctly
27
- @classmethod
28
- def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
29
- kwargs.pop("return_unused_kwargs", None)
30
- kwargs.pop("config", None)
31
-
32
- # Call the base PretrainedConfig's method directly
33
- return PretrainedConfig.from_pretrained(
34
- cls,
35
- pretrained_model_name_or_path,
36
- **kwargs
37
- )
38
-
39
-
40
- # ============================================================
41
- # ENCODER MODULE
42
- # ============================================================
43
  class Encoder(nn.Module):
44
  def __init__(self, base_encoder):
45
  super().__init__()
46
  self.encoder = base_encoder
47
 
48
  def forward(self, inputs):
49
- # NOTE: If the base encoder takes specific non-standard arguments,
50
- # this will break without **kwargs. We assume only input_ids/attention_mask are passed.
51
- outputs = self.encoder(
52
- input_ids=inputs["input_ids"],
53
- attention_mask=inputs["attention_mask"],
54
- output_hidden_states=True
55
- )
56
  last_hidden = outputs.hidden_states[-1]
57
-
58
  mask = inputs["attention_mask"].unsqueeze(-1).float()
59
  pooled = (last_hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
60
-
61
  return F.normalize(pooled, p=2, dim=1)
62
 
63
-
64
- # ============================================================
65
- # CLASSIFIER MODULE
66
- # ============================================================
67
  class Classifier(nn.Module):
68
  def __init__(self, input_dim=768, num_classes=28):
69
  super().__init__()
@@ -78,28 +29,17 @@ class Classifier(nn.Module):
78
  def forward(self, x):
79
  return self.mlp(x)
80
 
81
-
82
- # ============================================================
83
- # MAIN MODEL (NO *args or **kwargs)
84
- # ============================================================
85
  class EmoAxis(PreTrainedModel):
86
- config_class = EmoAxisConfig
87
-
88
- # Removed *args and **kwargs from signature
89
- def __init__(self, config):
90
  super().__init__(config)
91
-
92
- # This line remains correct for loading saved weights
93
- base_encoder = AutoModel.from_config(AutoConfig.from_pretrained(config.base_model_name))
94
-
95
  self.encoder = Encoder(base_encoder)
96
- self.classifier = Classifier(
97
- input_dim=base_encoder.config.hidden_size,
98
- num_classes=config.num_labels
99
- )
100
 
101
- # Removed **kwargs from signature
102
- def forward(self, input_ids=None, attention_mask=None):
103
  inputs = {
104
  "input_ids": input_ids,
105
  "attention_mask": attention_mask,
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from transformers import PreTrainedModel, AutoModel, AutoConfig
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
class Encoder(nn.Module):
    """Wrap a transformer backbone and emit L2-normalized, mask-aware
    mean-pooled sentence embeddings.

    The wrapped ``base_encoder`` is expected to accept keyword arguments
    such as ``input_ids`` / ``attention_mask`` and to return an object
    exposing ``hidden_states`` when called with
    ``output_hidden_states=True``.
    """

    def __init__(self, base_encoder):
        super().__init__()
        # Backbone is injected so the same pooling head works with any
        # HF-style encoder.
        self.encoder = base_encoder

    def forward(self, inputs):
        """Return a (batch, hidden) tensor of unit-norm pooled embeddings.

        ``inputs`` is a dict of tensors forwarded verbatim to the backbone
        (it must contain at least ``attention_mask``).
        """
        out = self.encoder(**inputs, output_hidden_states=True)
        # Pool over the final layer only.
        states = out.hidden_states[-1]
        # Broadcast the padding mask over the hidden dimension so padded
        # positions contribute nothing to the mean.
        weights = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (states * weights).sum(dim=1)
        # clamp guards against an all-padding row dividing by zero.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return F.normalize(summed / counts, p=2, dim=1)
17
 
 
 
 
 
18
  class Classifier(nn.Module):
19
  def __init__(self, input_dim=768, num_classes=28):
20
  super().__init__()
 
29
  def forward(self, x):
30
  return self.mlp(x)
31
 
 
 
 
 
32
  class EmoAxis(PreTrainedModel):
33
+ config_class = AutoConfig
34
+
35
+ def __init__(self, config):
 
36
  super().__init__(config)
37
+ base_encoder = AutoModel.from_config(config) # IMPORTANT: use from_config
 
 
 
38
  self.encoder = Encoder(base_encoder)
39
+ self.classifier = Classifier(input_dim=base_encoder.config.hidden_size,num_classes=config.num_labels)
 
 
 
40
 
41
+
42
+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
43
  inputs = {
44
  "input_ids": input_ids,
45
  "attention_mask": attention_mask,