# File size: 11,372 Bytes
# 1b12abd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
"""
ConceptFrameMet: Metaphor Detection with Frame and Source Domain Prediction

This model detects metaphors and predicts their semantic frames and source domains.
Based on AdaptiveSourceQAMelBert architecture.
"""

import inspect
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    RobertaModel,
    RobertaTokenizer,
)


class ConceptFrameMetForMetaphorDetection(nn.Module):
    """
    Metaphor detection model with semantic frame and source domain prediction capabilities.
    
    This model:
    - Detects metaphors in text
    - Predicts semantic frames for target words
    - Predicts source domains for metaphors
    """
    
    def __init__(
        self,
        encoder_model_name: str = "roberta-base",
        frame_qa_model_name: str = "nixie1981/sem_frames",
        source_qa_model_name: Optional[str] = None,
        classifier_hidden: int = 768,
        drop_ratio: float = 0.2,
        num_labels: int = 2,
        source_blend_mode: str = 'replacement',
        source_use_mode: str = 'metaphor_only',
        source_alpha: float = 0.3,
        metaphor_threshold: float = 0.5,
    ):
        """
        Build the encoder, the optional QA predictors and the classification heads.

        Args:
            encoder_model_name: Name or path given to ``RobertaModel.from_pretrained``
                and ``RobertaTokenizer.from_pretrained``.
            frame_qa_model_name: Name or path of the question-answering checkpoint
                used for frame prediction. Loading is best-effort: on failure a
                warning is printed and ``has_frame_predictor`` is False.
            source_qa_model_name: Name or path of the question-answering checkpoint
                used for source prediction. A falsy value disables it; a load
                failure prints a warning and leaves ``has_source_predictor`` False.
            classifier_hidden: Output width of the SPV and MIP linear layers.
            drop_ratio: Probability used for the dropout layer.
            num_labels: Number of outputs of the final classifier.
            source_blend_mode: Stored on the instance as configuration.
            source_use_mode: Stored on the instance as configuration.
            source_alpha: Stored on the instance as configuration.
            metaphor_threshold: Stored on the instance as configuration.
        """
        super().__init__()

        self.num_labels = num_labels
        self.classifier_hidden = classifier_hidden
        self.drop_ratio = drop_ratio

        # Configuration
        self.source_blend_mode = source_blend_mode
        self.source_use_mode = source_use_mode
        self.source_alpha = source_alpha
        self.metaphor_threshold = metaphor_threshold

        # Load encoder (RoBERTa); a failure here is fatal on purpose
        self.encoder = RobertaModel.from_pretrained(encoder_model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(encoder_model_name)
        self.config = self.encoder.config

        # Load frame QA model. Model and tokenizer are loaded into locals first
        # so that a failure of either one never leaves a half-initialised
        # predictor (e.g. a registered model without its tokenizer).
        try:
            frame_model = AutoModelForQuestionAnswering.from_pretrained(frame_qa_model_name)
            frame_tokenizer = AutoTokenizer.from_pretrained(frame_qa_model_name)
        except Exception as exc:
            # Best-effort by design: keep going without the predictor, but say
            # why. `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
            print(f"Warning: Frame QA model not available: {exc}")
            self.frame_qa_model = None
            self.frame_qa_tokenizer = None
            self.has_frame_predictor = False
        else:
            self.frame_qa_model = frame_model
            self.frame_qa_tokenizer = frame_tokenizer
            self.has_frame_predictor = True

        # Load source QA model (if a checkpoint name was given)
        self.source_qa_model = None
        self.source_qa_tokenizer = None
        self.has_source_predictor = False
        if source_qa_model_name:
            try:
                source_model = AutoModelForQuestionAnswering.from_pretrained(source_qa_model_name)
                source_tokenizer = AutoTokenizer.from_pretrained(source_qa_model_name)
            except Exception as exc:
                print(f"Warning: Source QA model not available: {exc}")
            else:
                self.source_qa_model = source_model
                self.source_qa_tokenizer = source_tokenizer
                self.has_source_predictor = True

        # Dropout
        self.dropout = nn.Dropout(drop_ratio)

        # Classification layers: SPV and MIP each take two concatenated
        # encoder-sized vectors; the classifier takes both hidden outputs.
        self.SPV_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.MIP_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.classifier = nn.Linear(classifier_hidden * 2, num_labels)

        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)

        self.logsoftmax = nn.LogSoftmax(dim=1)

        # Label maps start empty; nothing in this constructor populates them
        self.source_id2label = {}
        self.frame_id2label = {}
        
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    
    def predict_frames(self, sentence: str, target_word: str) -> Dict[str, any]:
        """
        Predict semantic frame for a target word in context
        
        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            
        Returns:
            Dictionary with frame prediction and confidence
        """
        if not self.has_frame_predictor:
            return {"frame": "UNKNOWN", "confidence": 0.0}
        
        inputs = self.frame_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        with torch.no_grad():
            outputs = self.frame_qa_model(**inputs)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            
            start_idx = torch.argmax(start_logits)
            end_idx = torch.argmax(end_logits)
            
            confidence = (torch.max(torch.softmax(start_logits, dim=-1)) + 
                         torch.max(torch.softmax(end_logits, dim=-1))) / 2.0
            
            frame_tokens = inputs['input_ids'][0][start_idx:end_idx+1]
            frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)
        
        return {
            "frame": frame if frame else "UNKNOWN",
            "confidence": confidence.item()
        }
    
    def predict_source(self, sentence: str, target_word: str) -> Dict[str, any]:
        """
        Predict source domain for a metaphor
        
        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            
        Returns:
            Dictionary with source prediction and confidence
        """
        if not self.has_source_predictor:
            return {"source": "UNKNOWN", "confidence": 0.0}
        
        inputs = self.source_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        with torch.no_grad():
            outputs = self.source_qa_model(**inputs)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs.start_logits
            
            probs = torch.softmax(logits, dim=-1)
            predicted_id = torch.argmax(probs, dim=-1)
            confidence = probs.gather(-1, predicted_id.unsqueeze(-1)).squeeze(-1)
            
            source = self.source_id2label.get(predicted_id.item(), "UNKNOWN")
        
        return {
            "source": source,
            "confidence": confidence.item()
        }
    
    def predict_metaphor(
        self,
        sentence: str,
        target_word: str,
        target_positions: Optional[List[int]] = None
    ) -> Dict[str, any]:
        """
        Predict if target word is metaphorical in context
        
        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            target_positions: Token positions of target word (optional)
            
        Returns:
            Dictionary with metaphor prediction, frame, and source
        """
        # Tokenize input
        inputs = self.tokenizer(
            sentence,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Create target mask
        if target_positions is None:
            # Find target word positions
            target_tokens = self.tokenizer.tokenize(target_word)
            sentence_tokens = self.tokenizer.tokenize(sentence)
            target_positions = []
            for i in range(len(sentence_tokens) - len(target_tokens) + 1):
                if sentence_tokens[i:i+len(target_tokens)] == target_tokens:
                    target_positions = list(range(i+1, i+1+len(target_tokens)))  # +1 for CLS token
                    break
        
        target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
        if target_positions:
            for pos in target_positions:
                if pos < target_mask.size(1):
                    target_mask[0, pos] = 1.0
        
        # Forward pass for metaphor detection
        with torch.no_grad():
            outputs = self.encoder(**inputs)
            sequence_output = outputs[0]
            pooled_output = outputs[1]
            
            # Get target output
            target_output = sequence_output * target_mask.unsqueeze(2)
            target_output = target_output.sum(dim=1) / (target_mask.sum(-1, keepdim=True) + 1e-10)
            target_output = self.dropout(target_output)
            pooled_output = self.dropout(pooled_output)
            
            # SPV and MIP
            SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
            MIP_hidden = self.MIP_linear(torch.cat([target_output, target_output], dim=1))
            
            # Classification
            logits = self.classifier(torch.cat([SPV_hidden, MIP_hidden], dim=1))
            logits = self.logsoftmax(logits)
            probs = torch.exp(logits)
            
            is_metaphor = torch.argmax(probs, dim=1).item() == 1
            metaphor_confidence = probs[0, 1].item()
        
        # Predict frame and source
        frame_result = self.predict_frames(sentence, target_word)
        source_result = self.predict_source(sentence, target_word) if is_metaphor else {"source": "N/A", "confidence": 0.0}
        
        return {
            "is_metaphor": is_metaphor,
            "metaphor_confidence": metaphor_confidence,
            "frame": frame_result["frame"],
            "frame_confidence": frame_result["confidence"],
            "source": source_result["source"],
            "source_confidence": source_result["confidence"]
        }
    
    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """Load model from pretrained checkpoint"""
        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r') as f:
            config = json.load(f)
        
        # Initialize model
        model = cls(**kwargs)
        
        # Load weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        
        return model
    
    def save_pretrained(self, save_directory):
        """Save model to directory"""
        os.makedirs(save_directory, exist_ok=True)
        
        # Save weights
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        
        # Save config
        config = {
            "_name_or_path": "ConceptFrameMet",
            "architectures": ["ConceptFrameMetForMetaphorDetection"],
            "model_type": "conceptframemet",
            "num_labels": self.num_labels,
            "classifier_hidden": self.classifier_hidden,
            "drop_ratio": self.drop_ratio,
            "source_blend_mode": self.source_blend_mode,
            "source_use_mode": self.source_use_mode,
            "source_alpha": self.source_alpha,
            "metaphor_threshold": self.metaphor_threshold,
        }
        
        with open(os.path.join(save_directory, "config.json"), 'w') as f:
            json.dump(config, f, indent=2)
        
        # Save tokenizer
        self.tokenizer.save_pretrained(save_directory)