AbstractPhil committed on
Commit
7f7daa9
·
verified ·
1 Parent(s): 2028a79

added output_hidden_states flag to allow hidden-state output from the model

Browse files
Files changed (1) hide show
  1. modeling_caption_bert.py +14 -4
modeling_caption_bert.py CHANGED
@@ -95,7 +95,8 @@ class CaptionBertModel(PreTrainedModel):
95
 
96
  self.post_init()
97
 
98
- def forward(self, input_ids=None, attention_mask=None, **kwargs):
 
99
  B, L = input_ids.shape
100
  device = input_ids.device
101
 
@@ -110,7 +111,12 @@ class CaptionBertModel(PreTrainedModel):
110
  else:
111
  key_padding_mask = (input_ids == self.config.pad_token_id)
112
 
113
- x = self.encoder(x, src_key_padding_mask=key_padding_mask)
 
 
 
 
 
114
 
115
  # Mean pool over non-padding tokens
116
  if attention_mask is not None:
@@ -123,10 +129,14 @@ class CaptionBertModel(PreTrainedModel):
123
  embedding = F.normalize(self.output_proj(pooled), dim=-1)
124
 
125
  # Return in HuggingFace-compatible format
126
- return type('Output', (), {
127
  'last_hidden_state': embedding,
128
  'pooler_output': embedding,
129
- })()
 
 
 
 
130
 
131
  def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
132
  device=None):
 
95
 
96
  self.post_init()
97
 
98
+ def forward(self, input_ids=None, attention_mask=None,
99
+ output_hidden_states=False, **kwargs):
100
  B, L = input_ids.shape
101
  device = input_ids.device
102
 
 
111
  else:
112
  key_padding_mask = (input_ids == self.config.pad_token_id)
113
 
114
+ # Layer-by-layer for hidden state capture
115
+ hidden_states = [x] if output_hidden_states else None
116
+ for layer in self.encoder.layers:
117
+ x = layer(x, src_key_padding_mask=key_padding_mask)
118
+ if output_hidden_states:
119
+ hidden_states.append(x)
120
 
121
  # Mean pool over non-padding tokens
122
  if attention_mask is not None:
 
129
  embedding = F.normalize(self.output_proj(pooled), dim=-1)
130
 
131
  # Return in HuggingFace-compatible format
132
+ result = {
133
  'last_hidden_state': embedding,
134
  'pooler_output': embedding,
135
+ }
136
+ if output_hidden_states:
137
+ result['hidden_states'] = tuple(hidden_states)
138
+
139
+ return type('Output', (), result)()
140
 
141
  def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
142
  device=None):