File size: 16,815 Bytes
1b12abd
36988a2
1b12abd
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b12abd
 
 
 
36988a2
1b12abd
 
36988a2
 
 
 
 
 
 
1b12abd
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
1b12abd
36988a2
748a347
 
 
 
 
 
 
 
 
 
 
 
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
939da56
 
 
 
 
 
 
 
36988a2
 
 
 
 
748a347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36988a2
 
 
 
 
 
 
1b12abd
 
 
 
 
 
 
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b12abd
 
 
 
36988a2
 
 
1b12abd
 
36988a2
 
 
1b12abd
36988a2
1b12abd
 
36988a2
 
 
1b12abd
36988a2
1b12abd
b6384b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850ff6f
 
 
 
36988a2
1b12abd
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b12abd
36988a2
 
1b12abd
36988a2
 
 
 
1b12abd
 
 
 
36988a2
 
1b12abd
36988a2
 
 
 
1b12abd
36988a2
 
00272b3
36988a2
 
 
 
7f84c5d
36988a2
 
 
 
7f84c5d
00272b3
36988a2
 
7f84c5d
36988a2
 
 
 
1b12abd
36988a2
 
 
 
 
1b12abd
36988a2
 
 
 
 
 
 
 
 
1b12abd
36988a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b12abd
36988a2
 
 
 
 
 
 
 
 
1b12abd
36988a2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
"""
Adaptive Source QA MelBERT with Configurable Blending Strategies

This model provides configurable approaches to incorporating source domain information:

FLAGS:
1. --source_blend_mode: 'additive' or 'replacement' (default: 'replacement')
   - additive: enhanced = target + alpha * source (keeps target strength)
   - replacement: blended = conf * source + (1-conf) * target (original approach)

2. --source_use_mode: 'metaphor_only' or 'all' (default: 'all')
   - metaphor_only: Only use source for samples with high metaphor probability
   - all: Use source for all samples

3. --source_alpha: float (default: 0.3) - scaling factor for additive mode

4. --metaphor_threshold: float (default: 0.5) - threshold for metaphor-only mode

Architecture:
- CONTEXT: target_word in full sentence β†’ encoder 1 β†’ target_context_embedding  
- SOURCE: [SEP] sentence [SEP] target [SEP] β†’ QA model β†’ predict source + confidence
- ISOLATED: isolated target β†’ encoder 2 β†’ target_embedding
- BLEND: Configurable (additive or replacement)
- FILTER: Configurable (metaphor-only or all)
- MIP: [enhanced_embedding, target_context_embedding]
- SPV: [pooled, enhanced_embedding] or [pooled, target_context_embedding]
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveSourceQAMelBert(nn.Module):
    """MelBERT with configurable source-domain blending strategies.

    One shared encoder (``Model``) is run on the full sentence, on the
    isolated target and, when a source QA model is supplied, on the text of
    the predicted source-domain label. The source and isolated-target
    embeddings are blended per ``source_blend_mode``, optionally gated per
    sample by a preliminary metaphor score (``source_use_mode``), and fed to
    the SPV/MIP heads.
    """

    # Accepted values for the two string flags read from ``args``.
    _BLEND_MODES = ('additive', 'replacement')
    _USE_MODES = ('metaphor_only', 'all')

    def __init__(self, args, Model, config, Source_QA_Model,
                 source_qa_tokenizer, melbert_tokenizer, num_labels=2):
        """
        Initialize the model with configurable flags.

        Args:
            args: Configuration namespace. Required attributes: ``drop_ratio``,
                ``classifier_hidden``, ``small_mean``, ``spv_isolate``,
                ``max_seq_length``. Optional (defaults in parentheses):
                - source_blend_mode: 'additive' or 'replacement' ('replacement')
                - source_use_mode: 'metaphor_only' or 'all' ('all')
                - source_alpha: scaling factor for additive mode (0.3)
                - metaphor_threshold: threshold for metaphor-only mode (0.5)
                - unfreeze_source_qa: make the source QA model trainable (False)
            Model: MelBERT encoder (RoBERTa/BERT). Its outputs are indexed as
                [0] = token states, [1] = pooled output.
            config: Model configuration (``hidden_size``, ``initializer_range``).
            Source_QA_Model: QA-style model that predicts the source domain, or
                None to fall back to isolated-target embeddings.
            source_qa_tokenizer: Tokenizer for the QA model.
            melbert_tokenizer: Tokenizer for the MelBERT encoder.
            num_labels: Number of metaphor classes (2: literal/metaphorical).

        Raises:
            ValueError: if ``source_blend_mode`` or ``source_use_mode`` is not
                one of the accepted values.
        """
        super().__init__()

        # Configuration flags with defaults. Validate first so a typo fails
        # loudly instead of silently falling back to 'replacement' / 'all',
        # and before the passed-in encoder is mutated below.
        self.source_blend_mode = getattr(args, 'source_blend_mode', 'replacement')
        self.source_use_mode = getattr(args, 'source_use_mode', 'all')
        self.source_alpha = getattr(args, 'source_alpha', 0.3)
        self.metaphor_threshold = getattr(args, 'metaphor_threshold', 0.5)
        if self.source_blend_mode not in self._BLEND_MODES:
            raise ValueError(
                f"source_blend_mode must be one of {self._BLEND_MODES}, "
                f"got {self.source_blend_mode!r}")
        if self.source_use_mode not in self._USE_MODES:
            raise ValueError(
                f"source_use_mode must be one of {self._USE_MODES}, "
                f"got {self.source_use_mode!r}")

        self.num_labels = num_labels
        self.encoder = Model
        # Training uses token types 0-3, so the encoder needs 4 type embeddings.
        self._resize_token_type_embeddings(config, num_types=4)

        self.source_qa_model = Source_QA_Model
        self.source_qa_tokenizer = source_qa_tokenizer
        self.melbert_tokenizer = melbert_tokenizer
        self.config = config
        self.dropout = nn.Dropout(args.drop_ratio)
        self.args = args

        # Freeze (default) or unfreeze the source QA model, if there is one.
        if self.source_qa_model is not None:
            trainable = bool(getattr(args, 'unfreeze_source_qa', False))
            for param in self.source_qa_model.parameters():
                param.requires_grad = trainable

        # {class id: source-domain label}; empty if no label file was found.
        self.source_id2label = self._load_source_labels()

        # SPV and MIP linear layers
        self.SPV_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.MIP_linear = nn.Linear(config.hidden_size * 2, args.classifier_hidden)
        self.classifier = nn.Linear(args.classifier_hidden * 2, num_labels)

        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)

        self.logsoftmax = nn.LogSoftmax(dim=1)

        # Print configuration
        print(f"\n{'='*80}")
        print("✓ AdaptiveSourceQAMelBert initialized")
        print(f"  - Blend Mode: {self.source_blend_mode.upper()}")
        if self.source_blend_mode == 'additive':
            print(f"    → enhanced = target + {self.source_alpha} * conf * source")
        else:
            print("    → blended = conf * source + (1-conf) * target")
        print(f"  - Use Mode: {self.source_use_mode.upper()}")
        if self.source_use_mode == 'metaphor_only':
            print(f"    → Only use source when metaphor_score > {self.metaphor_threshold}")
        else:
            print("    → Use source for all samples")
        print(f"{'='*80}\n")

    def _resize_token_type_embeddings(self, config, num_types=4):
        """Give the encoder a token-type embedding table with ``num_types`` rows.

        Existing rows (up to ``num_types``) are copied so pretrained type
        embeddings are kept. Any additional rows are drawn from
        N(0, config.initializer_range). Does nothing when the encoder has no
        ``embeddings.token_type_embeddings`` or the table already has
        ``num_types`` rows.
        """
        embeddings = getattr(self.encoder, 'embeddings', None)
        old = getattr(embeddings, 'token_type_embeddings', None)
        if old is None or old.weight.shape[0] == num_types:
            return

        new = nn.Embedding(num_types, old.embedding_dim).to(
            device=old.weight.device, dtype=old.weight.dtype)
        n_keep = min(old.weight.shape[0], num_types)
        with torch.no_grad():
            new.weight[:n_keep] = old.weight[:n_keep]
            new.weight[n_keep:].normal_(mean=0.0, std=config.initializer_range)
        embeddings.token_type_embeddings = new
        if hasattr(self.encoder, 'config'):
            self.encoder.config.type_vocab_size = num_types

    def _load_source_labels(self):
        """Return ``{id: label}`` from the first readable ``source_labels.json``.

        The file is expected to map label -> id. Candidates are tried in
        order: the current working directory, ``source_finder/`` under the
        current working directory, then the directory containing this module.
        Returns an empty dict (after printing a warning) if none can be read.
        """
        import json
        import os

        candidate_paths = [
            'source_labels.json',  # current working directory
            'source_finder/source_labels.json',  # original location
        ]
        module_file = globals().get('__file__')
        if module_file:
            # Next to this file
            candidate_paths.append(
                os.path.join(os.path.dirname(module_file), 'source_labels.json'))

        for path in candidate_paths:
            try:
                with open(path, 'r') as f:
                    source_label2id = json.load(f)
                id2label = {v: k for k, v in source_label2id.items()}
            except (OSError, ValueError, TypeError, AttributeError):
                # Missing/unreadable file, invalid JSON, or not a
                # {label: id} object -- try the next candidate.
                continue
            print(f"✓ Loaded {len(id2label)} source domain labels from {path}")
            return id2label

        print("❌ Warning: Could not load source labels from any location")
        return {}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _masked_pool(self, masked_output, mask):
        """Pool already-masked token states to one vector per sample.

        With ``args.small_mean`` this is a plain mean over every position,
        padding included. Otherwise it is the sum divided by the number of
        masked positions. That count is clamped to 1, so an all-zero mask
        gives a zero vector instead of NaN.
        (Assumes masked_output is [B, T, H] and mask is [B, T] -- TODO confirm.)
        """
        if self.args.small_mean:
            return masked_output.mean(1)
        return masked_output.sum(dim=1) / mask.sum(-1, keepdim=True).clamp(min=1)

    def _encode_and_pool(self, input_ids, attention_mask, pool_mask):
        """Run the shared encoder and pool its token states over ``pool_mask``."""
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        masked = outputs[0] * pool_mask.unsqueeze(2)
        return self._masked_pool(masked, pool_mask)

    def _decode_sentences_and_targets(self, input_ids, target_mask):
        """Decode each sentence and its target span back to text.

        Samples whose target mask is empty get the placeholder "unknown".
        """
        sentences, target_words = [], []
        for ids, mask in zip(input_ids, target_mask):
            sentences.append(self.melbert_tokenizer.decode(ids, skip_special_tokens=True))
            positions = mask.nonzero(as_tuple=True)[0]
            if len(positions) > 0:
                target_words.append(
                    self.melbert_tokenizer.decode(ids[positions], skip_special_tokens=True))
            else:
                target_words.append("unknown")
        return sentences, target_words

    def predict_source_and_embeddings(self, input_ids, target_mask, attention_mask,
                                       input_ids_2, target_mask_2, attention_mask_2):
        """
        Predict the source domain and compute source / isolated-target embeddings.

        Without a source QA model, the isolated-target embedding is returned
        as both source and target, with a constant confidence of 0.5. In
        replacement mode this therefore reduces to the isolated target.

        Returns:
            source_embeddings: [batch_size, hidden_size]
            target_embeddings: [batch_size, hidden_size]
            confidences: [batch_size] - probability of the predicted source class
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        if self.source_qa_model is None:
            target_embeddings_2 = self._encode_and_pool(
                input_ids_2, attention_mask_2, target_mask_2)
            confidences = torch.full((batch_size,), 0.5, device=device)
            return target_embeddings_2, target_embeddings_2, confidences

        # 1. Recover sentence / target text for the QA model.
        sentences, target_words = self._decode_sentences_and_targets(input_ids, target_mask)

        # 2. Encode (sentence, target) as a text pair for the QA model.
        qa_inputs = self.source_qa_tokenizer(
            sentences,
            target_words,
            max_length=self.args.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        qa_inputs = {k: v.to(device) for k, v in qa_inputs.items()}
        # A FrameAwareSourcePredictor also takes frame inputs; here they are
        # the same as the source inputs.
        if hasattr(self.source_qa_model, 'frame_finder'):
            qa_inputs['frame_input_ids'] = qa_inputs['input_ids']
            qa_inputs['frame_attention_mask'] = qa_inputs['attention_mask']

        # 3. Predict the source class and its probability. This runs with grad
        # so an unfrozen QA model can be trained through the confidences.
        source_logits = self.source_qa_model(**qa_inputs).logits
        source_probs = torch.softmax(source_logits, dim=-1)
        predicted_source_ids = torch.argmax(source_logits, dim=-1)
        confidences = source_probs.gather(1, predicted_source_ids.unsqueeze(1)).squeeze(1)

        # Map ids to label text (one host transfer instead of one per element).
        predicted_sources = [self.source_id2label.get(sid, "UNKNOWN")
                             for sid in predicted_source_ids.tolist()]

        # 4. Encode the predicted source words; pool over all non-pad tokens.
        source_inputs = self.melbert_tokenizer(
            predicted_sources,
            max_length=self.args.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        source_inputs = {k: v.to(device) for k, v in source_inputs.items()}
        source_target_mask = (
            source_inputs['input_ids'] != self.melbert_tokenizer.pad_token_id).float()
        source_embeddings = self._encode_and_pool(
            source_inputs['input_ids'], source_inputs['attention_mask'], source_target_mask)

        # 5. Encode the original isolated target words.
        target_embeddings_2 = self._encode_and_pool(
            input_ids_2, attention_mask_2, target_mask_2)

        return source_embeddings, target_embeddings_2, confidences

    def blend_embeddings(self, source_embeddings, target_embeddings, confidences):
        """
        Blend source and target embeddings according to ``source_blend_mode``.

        - additive:    target + source_alpha * conf * source
        - replacement: conf * source + (1 - conf) * target

        Args:
            source_embeddings: [batch_size, hidden_size]
            target_embeddings: [batch_size, hidden_size]
            confidences: [batch_size]

        Returns:
            blended_embeddings: [batch_size, hidden_size]
        """
        confidence_weights = confidences.unsqueeze(1)

        if self.source_blend_mode == 'additive':
            # Keeps the target at full weight and adds a confidence-scaled source.
            return target_embeddings + self.source_alpha * confidence_weights * source_embeddings
        # Soft interpolation: a confident source replaces the target.
        return confidence_weights * source_embeddings + (1 - confidence_weights) * target_embeddings

    def _metaphor_gate(self, pooled_output, target_output):
        """Return a [B, 1] 0/1 mask: 1 where the preliminary metaphor probability
        exceeds ``metaphor_threshold``.

        The preliminary score reuses SPV_linear and the classifier, feeding the
        SPV hidden state into both halves of the classifier input. Because the
        threshold makes the mask non-differentiable, the score is computed
        under ``no_grad`` to avoid building an unused autograd graph.
        """
        with torch.no_grad():
            prelim_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
            prelim_logits = self.classifier(torch.cat([prelim_hidden, prelim_hidden], dim=1))
            metaphor_scores = torch.exp(self.logsoftmax(prelim_logits))[:, 1]
        return (metaphor_scores > self.metaphor_threshold).float().unsqueeze(1)

    def forward(
        self,
        input_ids,
        input_ids_2,
        target_mask,
        target_mask_2,
        attention_mask_2,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        head_mask=None,
        input_with_mask_ids=None
    ):
        """
        Forward pass with configurable source blending.

        Returns the NLL loss when ``labels`` is given, otherwise log-probabilities
        of shape [batch_size, num_labels]. ``input_with_mask_ids`` is accepted
        for interface compatibility and is not used.
        """
        # ===== Target in context =====
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]
        pooled_output = outputs[1]

        target_output = self.dropout(sequence_output * target_mask.unsqueeze(2))
        pooled_output = self.dropout(pooled_output)
        target_output = self._masked_pool(target_output, target_mask)

        # ===== Source and isolated-target embeddings =====
        source_embeddings, target_embeddings_2, confidences = self.predict_source_and_embeddings(
            input_ids, target_mask, attention_mask,
            input_ids_2, target_mask_2, attention_mask_2
        )

        # ===== Per-sample gate: which samples get the source =====
        if self.source_use_mode == 'metaphor_only':
            use_source_mask = self._metaphor_gate(pooled_output, target_output)
        else:
            use_source_mask = torch.ones(
                source_embeddings.size(0), 1, device=source_embeddings.device)

        # ===== Blend, then fall back to the plain target where gated off =====
        blended_embedding = self.blend_embeddings(source_embeddings, target_embeddings_2, confidences)
        final_embedding = use_source_mask * blended_embedding + (1 - use_source_mask) * target_embeddings_2
        final_embedding = self.dropout(final_embedding)

        # ===== SPV and MIP =====
        spv_input = final_embedding if self.args.spv_isolate else target_output
        SPV_hidden = self.SPV_linear(torch.cat([pooled_output, spv_input], dim=1))
        MIP_hidden = self.MIP_linear(torch.cat([final_embedding, target_output], dim=1))

        # Final classification
        logits = self.classifier(self.dropout(torch.cat([SPV_hidden, MIP_hidden], dim=1)))
        logits = self.logsoftmax(logits)

        if labels is not None:
            loss_fct = nn.NLLLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits