nixie1981 committed on
Commit
36988a2
·
verified ·
1 Parent(s): 7f84c5d

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +282 -296
modeling_conceptframemet.py CHANGED
@@ -1,102 +1,96 @@
1
  """
2
- ConceptFrameMet: Metaphor Detection with Frame and Source Domain Prediction
3
 
4
- This model detects metaphors and predicts their semantic frames and source domains.
5
- Based on AdaptiveSourceQAMelBert architecture.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
  import torch
9
  import torch.nn as nn
10
- from transformers import RobertaModel, RobertaTokenizer, AutoModelForQuestionAnswering, AutoTokenizer
11
- from typing import Dict, List, Tuple, Optional
12
- import json
13
- import os
14
 
15
 
16
- class ConceptFrameMetForMetaphorDetection(nn.Module):
17
- """
18
- Metaphor detection model with semantic frame and source domain prediction capabilities.
19
-
20
- This model:
21
- - Detects metaphors in text
22
- - Predicts semantic frames for target words
23
- - Predicts source domains for metaphors
24
- """
25
-
26
- def __init__(
27
- self,
28
- encoder_model_name="roberta-base",
29
- frame_qa_model_name="nixie1981/sem_frames",
30
- source_qa_model_name=None,
31
- classifier_hidden=768,
32
- drop_ratio=0.2,
33
- num_labels=2,
34
- source_blend_mode='replacement',
35
- source_use_mode='metaphor_only',
36
- source_alpha=0.3,
37
- metaphor_threshold=0.5,
38
- ):
39
- super().__init__()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  self.num_labels = num_labels
42
- self.classifier_hidden = classifier_hidden
43
- self.drop_ratio = drop_ratio
44
-
45
- # Configuration
46
- self.source_blend_mode = source_blend_mode
47
- self.source_use_mode = source_use_mode
48
- self.source_alpha = source_alpha
49
- self.metaphor_threshold = metaphor_threshold
50
-
51
- # Load encoder (RoBERTa) with correct type_vocab_size
52
- from transformers import RobertaConfig
53
-
54
- # Load base model first
55
- self.encoder = RobertaModel.from_pretrained(encoder_model_name)
56
-
57
- # Resize token_type_embeddings to match training (type_vocab_size=4)
58
- # This is needed because the model was trained with 4 token types
59
- if self.encoder.embeddings.token_type_embeddings.weight.shape[0] != 4:
60
- old_embeddings = self.encoder.embeddings.token_type_embeddings
61
- new_embeddings = nn.Embedding(4, old_embeddings.embedding_dim)
62
- # Copy the original embedding (for type 0)
63
- new_embeddings.weight.data[0] = old_embeddings.weight.data[0]
64
- # Initialize the rest
65
- new_embeddings.weight.data[1:].normal_(mean=0.0, std=self.encoder.config.initializer_range)
66
- self.encoder.embeddings.token_type_embeddings = new_embeddings
67
- self.encoder.config.type_vocab_size = 4
68
-
69
- self.tokenizer = RobertaTokenizer.from_pretrained(encoder_model_name)
70
- self.config = self.encoder.config
71
-
72
- # Load frame QA model
73
- try:
74
- self.frame_qa_model = AutoModelForQuestionAnswering.from_pretrained(frame_qa_model_name)
75
- self.frame_qa_tokenizer = AutoTokenizer.from_pretrained(frame_qa_model_name)
76
- self.has_frame_predictor = True
77
- except:
78
- print("Warning: Frame QA model not available")
79
- self.has_frame_predictor = False
80
-
81
- # Load source QA model (if available)
82
- if source_qa_model_name:
83
- try:
84
- self.source_qa_model = AutoModelForQuestionAnswering.from_pretrained(source_qa_model_name)
85
- self.source_qa_tokenizer = AutoTokenizer.from_pretrained(source_qa_model_name)
86
- self.has_source_predictor = True
87
- except:
88
- print("Warning: Source QA model not available")
89
- self.has_source_predictor = False
90
  else:
91
- self.has_source_predictor = False
92
-
93
- # Dropout
94
- self.dropout = nn.Dropout(drop_ratio)
95
-
96
- # Classification layers
97
- self.SPV_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
98
- self.MIP_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
99
- self.classifier = nn.Linear(classifier_hidden * 2, num_labels)
 
 
 
 
 
 
 
 
 
100
 
101
  self._init_weights(self.SPV_linear)
102
  self._init_weights(self.MIP_linear)
@@ -104,245 +98,237 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
104
 
105
  self.logsoftmax = nn.LogSoftmax(dim=1)
106
 
107
- # Load source and frame labels
108
- self.source_id2label = {}
109
- self.frame_id2label = {}
110
-
 
 
 
 
 
 
 
 
 
 
 
111
  def _init_weights(self, module):
112
  """Initialize the weights"""
113
  if isinstance(module, (nn.Linear, nn.Embedding)):
114
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 
115
  if isinstance(module, nn.Linear) and module.bias is not None:
116
  module.bias.data.zero_()
117
-
118
- def predict_frames(self, sentence: str, target_word: str) -> Dict[str, any]:
 
119
  """
120
- Predict semantic frame for a target word in context
121
 
122
- Args:
123
- sentence: Input sentence
124
- target_word: Target word to analyze
125
-
126
  Returns:
127
- Dictionary with frame prediction and confidence
 
 
128
  """
129
- if not self.has_frame_predictor:
130
- return {"frame": "UNKNOWN", "confidence": 0.0}
131
 
132
- try:
133
- inputs = self.frame_qa_tokenizer(
134
- sentence,
135
- target_word,
136
- max_length=150,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  padding='max_length',
138
  truncation=True,
139
  return_tensors='pt'
140
  )
 
141
 
142
- with torch.no_grad():
143
- outputs = self.frame_qa_model(**inputs)
144
-
145
- # Check if it has start/end logits
146
- if hasattr(outputs, 'start_logits') and hasattr(outputs, 'end_logits'):
147
- start_logits = outputs.start_logits
148
- end_logits = outputs.end_logits
149
-
150
- start_idx = torch.argmax(start_logits)
151
- end_idx = torch.argmax(end_logits)
152
-
153
- confidence = (torch.max(torch.softmax(start_logits, dim=-1)) +
154
- torch.max(torch.softmax(end_logits, dim=-1))) / 2.0
155
-
156
- frame_tokens = inputs['input_ids'][0][start_idx:end_idx+1]
157
- frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)
158
- else:
159
- # Fallback if model structure is different
160
- frame = "Self_motion"
161
- confidence = 0.5
162
-
163
- return {
164
- "frame": frame if frame else "UNKNOWN",
165
- "confidence": confidence.item() if isinstance(confidence, torch.Tensor) else confidence
166
- }
167
- except Exception as e:
168
- # If frame prediction fails, return a default
169
- print(f"Frame prediction warning: {e}")
170
- return {"frame": "UNKNOWN", "confidence": 0.0}
171
-
172
- def predict_source(self, sentence: str, target_word: str) -> Dict[str, any]:
173
- """
174
- Predict source domain for a metaphor
175
-
176
- Args:
177
- sentence: Input sentence
178
- target_word: Target word to analyze
179
-
180
- Returns:
181
- Dictionary with source prediction and confidence
182
- """
183
- if not self.has_source_predictor:
184
- return {"source": "UNKNOWN", "confidence": 0.0}
185
-
186
- inputs = self.source_qa_tokenizer(
187
- sentence,
188
- target_word,
189
- max_length=150,
190
- padding='max_length',
191
- truncation=True,
192
- return_tensors='pt'
193
- )
194
-
195
  with torch.no_grad():
196
- outputs = self.source_qa_model(**inputs)
197
- logits = outputs.logits if hasattr(outputs, 'logits') else outputs.start_logits
198
-
199
- probs = torch.softmax(logits, dim=-1)
200
- predicted_id = torch.argmax(probs, dim=-1)
201
- confidence = probs.gather(-1, predicted_id.unsqueeze(-1)).squeeze(-1)
202
-
203
- source = self.source_id2label.get(predicted_id.item(), "UNKNOWN")
204
-
205
- return {
206
- "source": source,
207
- "confidence": confidence.item()
208
- }
209
-
210
- def predict_metaphor(
211
- self,
212
- sentence: str,
213
- target_word: str,
214
- target_positions: Optional[List[int]] = None
215
- ) -> Dict[str, any]:
216
- """
217
- Predict if target word is metaphorical in context
218
 
219
- Args:
220
- sentence: Input sentence
221
- target_word: Target word to analyze
222
- target_positions: Token positions of target word (optional)
223
-
224
- Returns:
225
- Dictionary with metaphor prediction, frame, and source
226
- """
227
- # Tokenize input
228
- inputs = self.tokenizer(
229
- sentence,
230
- max_length=150,
231
  padding='max_length',
232
  truncation=True,
233
  return_tensors='pt'
234
  )
 
 
235
 
236
- # Create target mask
237
- if target_positions is None:
238
- # Find target word positions
239
- target_tokens = self.tokenizer.tokenize(target_word)
240
- sentence_tokens = self.tokenizer.tokenize(sentence)
241
- target_positions = []
242
- for i in range(len(sentence_tokens) - len(target_tokens) + 1):
243
- if sentence_tokens[i:i+len(target_tokens)] == target_tokens:
244
- target_positions = list(range(i+1, i+1+len(target_tokens))) # +1 for CLS token
245
- break
246
-
247
- target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
248
- if target_positions:
249
- for pos in target_positions:
250
- if pos < target_mask.size(1):
251
- target_mask[0, pos] = 1.0
252
-
253
- # Forward pass for metaphor detection
254
- with torch.no_grad():
255
- outputs = self.encoder(**inputs)
256
- sequence_output = outputs[0]
257
- pooled_output = outputs[1]
258
-
259
- # Get target output
260
- target_output = sequence_output * target_mask.unsqueeze(2)
261
- target_output = target_output.sum(dim=1) / (target_mask.sum(-1, keepdim=True) + 1e-10)
262
- target_output = self.dropout(target_output)
263
- pooled_output = self.dropout(pooled_output)
264
-
265
- # SPV and MIP
266
- SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
267
- MIP_hidden = self.MIP_linear(torch.cat([target_output, target_output], dim=1))
268
-
269
- # Classification
270
- logits = self.classifier(torch.cat([SPV_hidden, MIP_hidden], dim=1))
271
- logits = self.logsoftmax(logits)
272
- probs = torch.exp(logits)
273
-
274
- is_metaphor = torch.argmax(probs, dim=1).item() == 1
275
- metaphor_confidence = probs[0, 1].item()
276
-
277
- # Predict frame and source
278
- frame_result = self.predict_frames(sentence, target_word)
279
- source_result = self.predict_source(sentence, target_word) if is_metaphor else {"source": "N/A", "confidence": 0.0}
280
 
281
- return {
282
- "is_metaphor": is_metaphor,
283
- "metaphor_confidence": metaphor_confidence,
284
- "frame": frame_result["frame"],
285
- "frame_confidence": frame_result["confidence"],
286
- "source": source_result["source"],
287
- "source_confidence": source_result["confidence"]
288
- }
289
-
290
- @classmethod
291
- def from_pretrained(cls, model_path, **kwargs):
292
- """Load model from pretrained checkpoint"""
293
- # Load weights first to check what's in checkpoint
294
- weights_path = os.path.join(model_path, "pytorch_model.bin")
295
- state_dict = torch.load(weights_path, map_location='cpu')
296
 
297
- # Check what's in the checkpoint
298
- has_source_in_checkpoint = any(k.startswith('source_qa_model.') for k in state_dict.keys())
299
- has_frame_in_checkpoint = any(k.startswith('frame_qa_model.') for k in state_dict.keys())
 
300
 
301
- # Initialize model:
302
- # - Download frame_qa_model (nixie1981/sem_frames) - NOT in checkpoint
303
- # - Don't download source_qa_model - IS in checkpoint
304
- model = cls(
305
- frame_qa_model_name="nixie1981/sem_frames", # Download - needed for frames!
306
- source_qa_model_name=None, # Don't download - in checkpoint
307
- **kwargs
308
  )
309
 
310
- # Manually set source flag since weights are in checkpoint
311
- if has_source_in_checkpoint:
312
- model.has_source_predictor = True
313
-
314
- # Load ALL weights from checkpoint (including source_qa_model)
315
- missing, unexpected = model.load_state_dict(state_dict, strict=False)
316
 
317
- print(f"Loaded {len(state_dict)} weights from checkpoint")
318
- if missing:
319
- print(f"Missing {len(missing)} keys")
 
320
 
321
- return model
322
-
323
- def save_pretrained(self, save_directory):
324
- """Save model to directory"""
325
- os.makedirs(save_directory, exist_ok=True)
326
 
327
- # Save weights
328
- torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
 
 
 
 
 
 
 
329
 
330
- # Save config
331
- config = {
332
- "_name_or_path": "ConceptFrameMet",
333
- "architectures": ["ConceptFrameMetForMetaphorDetection"],
334
- "model_type": "conceptframemet",
335
- "num_labels": self.num_labels,
336
- "classifier_hidden": self.classifier_hidden,
337
- "drop_ratio": self.drop_ratio,
338
- "source_blend_mode": self.source_blend_mode,
339
- "source_use_mode": self.source_use_mode,
340
- "source_alpha": self.source_alpha,
341
- "metaphor_threshold": self.metaphor_threshold,
342
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
- with open(os.path.join(save_directory, "config.json"), 'w') as f:
345
- json.dump(config, f, indent=2)
 
 
 
 
 
 
 
346
 
347
- # Save tokenizer
348
- self.tokenizer.save_pretrained(save_directory)
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Adaptive Source QA MelBERT with Configurable Blending Strategies
3
 
4
+ This model provides configurable approaches to incorporating source domain information:
5
+
6
+ FLAGS:
7
+ 1. --source_blend_mode: 'additive' or 'replacement' (default: 'replacement')
8
+ - additive: enhanced = target + alpha * conf * source (keeps target strength; conf is the source-prediction confidence)
9
+ - replacement: blended = conf * source + (1-conf) * target (original approach)
10
+
11
+ 2. --source_use_mode: 'metaphor_only' or 'all' (default: 'all')
12
+ - metaphor_only: Only use source for samples with high metaphor probability
13
+ - all: Use source for all samples
14
+
15
+ 3. --source_alpha: float (default: 0.3) - scaling factor for additive mode
16
+
17
+ 4. --metaphor_threshold: float (default: 0.5) - threshold for metaphor-only mode
18
+
19
+ Architecture:
20
+ - CONTEXT: target_word in full sentence → encoder 1 → target_context_embedding
21
+ - SOURCE: [SEP] sentence [SEP] target [SEP] → QA model → predict source + confidence
22
+ - ISOLATED: isolated target → encoder 2 → target_embedding
23
+ - BLEND: Configurable (additive or replacement)
24
+ - FILTER: Configurable (metaphor-only or all)
25
+ - MIP: [enhanced_embedding, target_context_embedding]
26
+ - SPV: [pooled, enhanced_embedding] or [pooled, target_context_embedding]
27
  """
28
 
29
  import torch
30
  import torch.nn as nn
31
+ import torch.nn.functional as F
 
 
 
32
 
33
 
34
+ class AdaptiveSourceQAMelBert(nn.Module):
35
+ """MelBERT with configurable source domain blending strategies"""
36
+
37
+ def __init__(self, args, Model, config, Source_QA_Model,
38
+ source_qa_tokenizer, melbert_tokenizer, num_labels=2):
39
+ """
40
+ Initialize the model with configurable flags
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ Args:
43
+ args: Configuration arguments with:
44
+ - source_blend_mode: 'additive' or 'replacement'
45
+ - source_use_mode: 'metaphor_only' or 'all'
46
+ - source_alpha: scaling factor for additive mode
47
+ - metaphor_threshold: threshold for metaphor-only mode
48
+ Model: MelBert encoder (RoBERTa/BERT)
49
+ config: Model configuration
50
+ Source_QA_Model: QA-style model to predict source domain
51
+ source_qa_tokenizer: Tokenizer for QA model
52
+ melbert_tokenizer: Tokenizer for MelBert
53
+ num_labels: Number of metaphor classes (2: literal/metaphorical)
54
+ """
55
+ super(AdaptiveSourceQAMelBert, self).__init__()
56
  self.num_labels = num_labels
57
+ self.encoder = Model
58
+ self.source_qa_model = Source_QA_Model
59
+ self.source_qa_tokenizer = source_qa_tokenizer
60
+ self.melbert_tokenizer = melbert_tokenizer
61
+ self.config = config
62
+ self.dropout = nn.Dropout(args.drop_ratio)
63
+ self.args = args
64
+
65
+ # Configuration flags with defaults
66
+ self.source_blend_mode = getattr(args, 'source_blend_mode', 'replacement')
67
+ self.source_use_mode = getattr(args, 'source_use_mode', 'all')
68
+ self.source_alpha = getattr(args, 'source_alpha', 0.3)
69
+ self.metaphor_threshold = getattr(args, 'metaphor_threshold', 0.5)
70
+
71
+ # Freeze or unfreeze source QA model
72
+ if not getattr(args, 'unfreeze_source_qa', False):
73
+ for param in self.source_qa_model.parameters():
74
+ param.requires_grad = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  else:
76
+ for param in self.source_qa_model.parameters():
77
+ param.requires_grad = True
78
+
79
+ # Load source labels
80
+ self.source_id2label = {}
81
+ try:
82
+ import json
83
+ with open('source_finder/source_labels.json', 'r') as f:
84
+ source_label2id = json.load(f)
85
+ self.source_id2label = {v: k for k, v in source_label2id.items()}
86
+ print(f"✓ Loaded {len(self.source_id2label)} source domain labels")
87
+ except Exception as e:
88
+ print(f"❌ Warning: Could not load source labels: {e}")
89
+
90
+ # SPV and MIP linear layers
91
+ self.SPV_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
92
+ self.MIP_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
93
+ self.classifier = nn.Linear(args.classifier_hidden * 2, num_labels)
94
 
95
  self._init_weights(self.SPV_linear)
96
  self._init_weights(self.MIP_linear)
 
98
 
99
  self.logsoftmax = nn.LogSoftmax(dim=1)
100
 
101
+ # Print configuration
102
+ print(f"\n{'='*80}")
103
+ print(f"✓ AdaptiveSourceQAMelBert initialized")
104
+ print(f" - Blend Mode: {self.source_blend_mode.upper()}")
105
+ if self.source_blend_mode == 'additive':
106
+ print(f" → enhanced = target + {self.source_alpha} * source")
107
+ else:
108
+ print(f" → blended = conf * source + (1-conf) * target")
109
+ print(f" - Use Mode: {self.source_use_mode.upper()}")
110
+ if self.source_use_mode == 'metaphor_only':
111
+ print(f" → Only use source when metaphor_score > {self.metaphor_threshold}")
112
+ else:
113
+ print(f" → Use source for all samples")
114
+ print(f"{'='*80}\n")
115
+
116
  def _init_weights(self, module):
117
  """Initialize the weights"""
118
  if isinstance(module, (nn.Linear, nn.Embedding)):
119
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
120
+ elif isinstance(module, nn.LayerNorm):
121
+ module.bias.data.zero_()
122
+ module.weight.data.fill_(1.0)
123
  if isinstance(module, nn.Linear) and module.bias is not None:
124
  module.bias.data.zero_()
125
+
126
def predict_source_and_embeddings(self, input_ids, target_mask, attention_mask,
                                  input_ids_2, target_mask_2, attention_mask_2):
    """Predict a source domain per sample and embed it next to the isolated target.

    Steps: decode each sentence and its masked target word, ask the source QA
    model for a source-domain label, encode that label text with
    ``self.encoder``, and encode the isolated target input with the same encoder.

    Args:
        input_ids: Token ids of the sentences, one row per sample.
        target_mask: Mask whose non-zero positions mark the target tokens
            in ``input_ids``.
        attention_mask: Not used by this method; kept for interface
            compatibility.
        input_ids_2: Token ids of the isolated target input.
        target_mask_2: Mask selecting the target tokens in ``input_ids_2``.
        attention_mask_2: Attention mask for ``input_ids_2``.

    Returns:
        source_embeddings: [batch_size, hidden_size]
        target_embeddings: [batch_size, hidden_size]
        confidences: [batch_size] - confidence scores
    """
    tokenizer = self.melbert_tokenizer
    device = input_ids.device
    max_len = self.args.max_seq_length
    small_mean = self.args.small_mean

    def pool(sequence_output, mask):
        """Average the masked token vectors of every row."""
        masked = sequence_output * mask.unsqueeze(2)
        if small_mean:
            return masked.mean(1)
        token_count = mask.sum(-1, keepdim=True)
        # An all-zero mask row would give 0/0 = NaN; divide by 1 there so the
        # row pools to a zero vector instead. Non-empty rows are unaffected.
        token_count = token_count.masked_fill(token_count == 0, 1)
        return masked.sum(dim=1) / token_count

    # 1. Decode sentences and extract target words
    sentences = []
    target_words = []
    for row_ids, row_mask in zip(input_ids, target_mask):
        sentences.append(tokenizer.decode(row_ids, skip_special_tokens=True))
        target_positions = row_mask.nonzero(as_tuple=True)[0]
        if len(target_positions) > 0:
            target_words.append(
                tokenizer.decode(row_ids[target_positions], skip_special_tokens=True)
            )
        else:
            target_words.append("unknown")

    # 2. Format QA input
    qa_inputs = self.source_qa_tokenizer(
        sentences,
        target_words,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    qa_inputs = {k: v.to(device) for k, v in qa_inputs.items()}

    # If source model is FrameAwareSourcePredictor, also pass frame inputs
    # (frame inputs are the same as source inputs for this use case)
    if hasattr(self.source_qa_model, 'frame_finder'):
        qa_inputs['frame_input_ids'] = qa_inputs['input_ids']
        qa_inputs['frame_attention_mask'] = qa_inputs['attention_mask']

    # 3. Get source predictions with confidence
    # NOTE(review): indentation was lost in the copy this method was rebuilt
    # from, so the original extent of the torch.no_grad() blocks could not be
    # confirmed. Here the QA pass runs without gradients unless
    # args.unfreeze_source_qa is set (and the caller has gradients enabled);
    # both encoder passes below keep gradients. Confirm against the training
    # setup before relying on this.
    qa_trainable = bool(getattr(self.args, 'unfreeze_source_qa', False))
    with torch.set_grad_enabled(torch.is_grad_enabled() and qa_trainable):
        qa_outputs = self.source_qa_model(**qa_inputs)
        source_logits = qa_outputs.logits
        source_probs = torch.softmax(source_logits, dim=-1)
        predicted_source_ids = torch.argmax(source_logits, dim=-1)
        # Confidence = probability assigned to the predicted label
        confidences = source_probs.gather(1, predicted_source_ids.unsqueeze(1)).squeeze(1)

    # Map predicted ids to source words (plain Python lookup, no autograd involved)
    predicted_sources = [
        self.source_id2label.get(source_id, "UNKNOWN")
        for source_id in predicted_source_ids.tolist()
    ]

    # 4. Encode predicted source words
    source_inputs = tokenizer(
        predicted_sources,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    source_inputs = {k: v.to(device) for k, v in source_inputs.items()}
    # Every non-padding token of the source text contributes to its embedding
    source_target_mask = (source_inputs['input_ids'] != tokenizer.pad_token_id).float()
    source_outputs = self.encoder(
        source_inputs['input_ids'],
        attention_mask=source_inputs['attention_mask']
    )
    source_embeddings = pool(source_outputs[0], source_target_mask)

    # 5. Encode original isolated target words
    target_outputs_2 = self.encoder(
        input_ids_2,
        attention_mask=attention_mask_2
    )
    target_embeddings_2 = pool(target_outputs_2[0], target_mask_2)

    return source_embeddings, target_embeddings_2, confidences
226
+
227
def blend_embeddings(self, source_embeddings, target_embeddings, confidences):
    """Blend source and target embeddings according to ``self.source_blend_mode``.

    Modes:
        'additive':  target + source_alpha * conf * source
                     (keeps the target and adds a confidence-scaled source term)
        otherwise:   conf * source + (1 - conf) * target
                     (replacement; soft interpolation by confidence)

    Args:
        source_embeddings: [batch_size, hidden_size]
        target_embeddings: [batch_size, hidden_size]
        confidences: [batch_size]; a [batch_size, 1] column is accepted too.

    Returns:
        blended_embeddings: [batch_size, hidden_size]
    """
    # One weight per row, shaped as a column so it broadcasts over hidden_size.
    # reshape (rather than unsqueeze) also copes with an already-columnar input.
    confidence_weights = confidences.reshape(-1, 1)

    if self.source_blend_mode == 'additive':
        return target_embeddings + self.source_alpha * confidence_weights * source_embeddings

    return confidence_weights * source_embeddings + (1 - confidence_weights) * target_embeddings
251
+
252
def forward(
    self,
    input_ids,
    input_ids_2,
    target_mask,
    target_mask_2,
    attention_mask_2,
    token_type_ids=None,
    attention_mask=None,
    labels=None,
    head_mask=None,
    input_with_mask_ids=None
):
    """Forward pass with configurable source blending.

    Args:
        input_ids: Token ids of the sentence containing the target word.
        input_ids_2: Token ids of the isolated target input.
        target_mask: Mask selecting the target tokens in ``input_ids``.
        target_mask_2: Mask selecting the target tokens in ``input_ids_2``.
        attention_mask_2: Attention mask for ``input_ids_2``.
        token_type_ids: Passed through to the encoder.
        attention_mask: Passed through to the encoder.
        labels: Optional class labels; when given, the loss is returned.
        head_mask: Passed through to the encoder.
        input_with_mask_ids: Not used by this method; kept for interface
            compatibility.

    Returns:
        The NLL loss when ``labels`` is given, otherwise the log-probabilities
        produced by ``self.logsoftmax``.
    """
    # ===== ENCODER 1: Target in context =====
    outputs = self.encoder(
        input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
    )
    sequence_output = outputs[0]
    pooled_output = outputs[1]

    # Keep only the target tokens, then apply dropout (target first, pooled second)
    target_output = self.dropout(sequence_output * target_mask.unsqueeze(2))
    pooled_output = self.dropout(pooled_output)

    if self.args.small_mean:
        target_output = target_output.mean(1)
    else:
        token_count = target_mask.sum(-1, keepdim=True)
        # A sample with no target tokens would give 0/0 = NaN and poison the
        # loss; divide by 1 there so it pools to a zero vector instead.
        token_count = token_count.masked_fill(token_count == 0, 1)
        target_output = target_output.sum(dim=1) / token_count

    # ===== ENCODER 2: Get source and target embeddings =====
    source_embeddings, target_embeddings_2, confidences = self.predict_source_and_embeddings(
        input_ids, target_mask, attention_mask,
        input_ids_2, target_mask_2, attention_mask_2
    )

    # ===== METAPHOR-ONLY FILTERING (if enabled) =====
    if self.source_use_mode == 'metaphor_only':
        # Preliminary metaphor score from the context features alone: the SPV
        # projection is fed twice to the classifier in place of [SPV, MIP].
        prelim_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
        prelim_logits = self.classifier(torch.cat([prelim_hidden, prelim_hidden], dim=1))
        metaphor_scores = torch.exp(self.logsoftmax(prelim_logits))[:, 1]

        # 1 where the source should be used, 0 elsewhere; same dtype as the
        # embeddings so the blend below does not change their precision.
        use_source_mask = (
            (metaphor_scores > self.metaphor_threshold)
            .to(source_embeddings.dtype)
            .unsqueeze(1)
        )
    else:
        # Use source for all samples; allocated directly on the right device/dtype
        use_source_mask = source_embeddings.new_ones(source_embeddings.size(0), 1)

    # ===== BLEND: Apply configured blending strategy =====
    blended_embedding = self.blend_embeddings(source_embeddings, target_embeddings_2, confidences)

    # Fall back to the plain isolated-target embedding where the mask is 0
    final_embedding = use_source_mask * blended_embedding + (1 - use_source_mask) * target_embeddings_2
    final_embedding = self.dropout(final_embedding)

    # ===== SPV and MIP =====
    spv_second = final_embedding if self.args.spv_isolate else target_output
    SPV_hidden = self.SPV_linear(torch.cat([pooled_output, spv_second], dim=1))
    MIP_hidden = self.MIP_linear(torch.cat([final_embedding, target_output], dim=1))

    # Final classification
    logits = self.classifier(self.dropout(torch.cat([SPV_hidden, MIP_hidden], dim=1)))
    logits = self.logsoftmax(logits)

    if labels is not None:
        loss_fct = nn.NLLLoss()
        return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    return logits