"""
ConceptFrameMet: Metaphor Detection with Frame and Source Domain Prediction

This model detects metaphors and predicts their semantic frames and source domains.
Based on the AdaptiveSourceQAMelBert architecture.
"""

import json
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer, AutoModelForQuestionAnswering, AutoTokenizer


class ConceptFrameMetForMetaphorDetection(nn.Module):
    """
    Metaphor detection model with semantic frame and source domain prediction capabilities.

    This model:
    - Detects metaphors in text
    - Predicts semantic frames for target words
    - Predicts source domains for metaphors

    The metaphor classifier is a RoBERTa encoder followed by two linear
    projections (``SPV_linear`` and ``MIP_linear``) and a final classifier.
    Frame and source prediction are delegated to optional question-answering
    models; when one of them cannot be loaded the corresponding ``predict_*``
    method returns ``"UNKNOWN"`` with confidence ``0.0``.

    ``source_blend_mode``, ``source_use_mode``, ``source_alpha`` and
    ``metaphor_threshold`` are stored and written to ``config.json`` but are
    not read by any method defined in this class.
    """

    # Constructor arguments that save_pretrained() writes to config.json and
    # from_pretrained() restores. Other config.json keys are metadata only.
    _CONFIG_KEYS = (
        "encoder_model_name",
        "frame_qa_model_name",
        "source_qa_model_name",
        "classifier_hidden",
        "drop_ratio",
        "num_labels",
        "source_blend_mode",
        "source_use_mode",
        "source_alpha",
        "metaphor_threshold",
    )

    def __init__(
        self,
        encoder_model_name="roberta-base",
        frame_qa_model_name="nixie1981/sem_frames",
        source_qa_model_name=None,
        classifier_hidden=768,
        drop_ratio=0.2,
        num_labels=2,
        source_blend_mode='replacement',
        source_use_mode='metaphor_only',
        source_alpha=0.3,
        metaphor_threshold=0.5,
    ):
        """
        Build the encoder, the optional QA predictors and the classifier head.

        Args:
            encoder_model_name: Name or path passed to ``RobertaModel.from_pretrained``.
            frame_qa_model_name: Name or path of the frame QA model.
            source_qa_model_name: Name or path of the source QA model; a falsy
                value disables source prediction without attempting a load.
            classifier_hidden: Output width of ``SPV_linear`` and ``MIP_linear``.
            drop_ratio: Dropout probability applied to the pooled and target vectors.
            num_labels: Number of classifier outputs.
            source_blend_mode: Stored and saved; not used by this class.
            source_use_mode: Stored and saved; not used by this class.
            source_alpha: Stored and saved; not used by this class.
            metaphor_threshold: Stored and saved; not used by this class.
        """
        super().__init__()

        # Remember the model names so save_pretrained() can record them and
        # from_pretrained() can rebuild the same architecture.
        self.encoder_model_name = encoder_model_name
        self.frame_qa_model_name = frame_qa_model_name
        self.source_qa_model_name = source_qa_model_name

        self.num_labels = num_labels
        self.classifier_hidden = classifier_hidden
        self.drop_ratio = drop_ratio

        self.source_blend_mode = source_blend_mode
        self.source_use_mode = source_use_mode
        self.source_alpha = source_alpha
        self.metaphor_threshold = metaphor_threshold

        self.encoder = RobertaModel.from_pretrained(encoder_model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(encoder_model_name)
        self.config = self.encoder.config

        # The QA predictors are best-effort: a failed load disables the
        # feature instead of aborting construction.
        frame_qa = self._load_qa_model(frame_qa_model_name, "Frame")
        self.has_frame_predictor = frame_qa is not None
        if frame_qa is not None:
            self.frame_qa_model, self.frame_qa_tokenizer = frame_qa

        source_qa = None
        if source_qa_model_name:
            source_qa = self._load_qa_model(source_qa_model_name, "Source")
        self.has_source_predictor = source_qa is not None
        if source_qa is not None:
            self.source_qa_model, self.source_qa_tokenizer = source_qa

        self.dropout = nn.Dropout(drop_ratio)

        self.SPV_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.MIP_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.classifier = nn.Linear(classifier_hidden * 2, num_labels)

        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)

        self.logsoftmax = nn.LogSoftmax(dim=1)

        # Label maps start empty; predict_source() reads source_id2label.
        self.source_id2label = {}
        self.frame_id2label = {}

    @staticmethod
    def _load_qa_model(model_name, label):
        """
        Load a QA model and its tokenizer, or return None if that fails.

        Args:
            model_name: Name or path passed to the ``from_pretrained`` loaders.
            label: Human-readable prefix used in the warning message.

        Returns:
            ``(model, tokenizer)`` on success, ``None`` on any load error.
        """
        try:
            model = AutoModelForQuestionAnswering.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as exc:
            # Deliberately broad: loading can fail for many reasons (missing
            # files, network, bad name). KeyboardInterrupt/SystemExit are not
            # Exception subclasses and therefore still propagate.
            print(f"Warning: {label} QA model not available ({exc})")
            return None
        return model, tokenizer

    @staticmethod
    def _module_device(module):
        """Return the device of ``module``'s first parameter (CPU if it has none)."""
        first_param = next(module.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    @staticmethod
    def _find_subsequence(haystack, needle):
        """
        Return the start index of the first occurrence of ``needle`` in ``haystack``.

        Both arguments are lists. Returns ``None`` when ``needle`` is empty or
        does not occur.
        """
        width = len(needle)
        if width == 0:
            return None
        for start in range(len(haystack) - width + 1):
            if haystack[start:start + width] == needle:
                return start
        return None

    def _find_target_positions(self, sentence, target_word):
        """
        Locate ``target_word`` in the encoder tokenization of ``sentence``.

        Returns:
            Token positions in the encoded input (shifted by one for the
            leading special token), or an empty list when no match is found.
        """
        bare_tokens = self.tokenizer.tokenize(target_word)
        if not bare_tokens:
            return []

        sentence_tokens = self.tokenizer.tokenize(sentence)
        # RoBERTa's BPE encodes a word preceded by a space differently from
        # the same word at the start of the text, so a word tokenized on its
        # own only matches at sentence start. Try the bare form first and
        # fall back to the space-prefixed form for mid-sentence words.
        candidates = (bare_tokens, self.tokenizer.tokenize(" " + target_word))
        for target_tokens in candidates:
            start = self._find_subsequence(sentence_tokens, target_tokens)
            if start is not None:
                return list(range(start + 1, start + 1 + len(target_tokens)))
        return []

    @classmethod
    def _init_kwargs_from_config(cls, config, overrides):
        """
        Merge saved hyperparameters with caller overrides.

        Args:
            config: Dict loaded from ``config.json``.
            overrides: Keyword arguments given to ``from_pretrained``.

        Returns:
            Constructor kwargs: the recognised keys of ``config`` updated
            with ``overrides`` (overrides win).
        """
        init_kwargs = {key: config[key] for key in cls._CONFIG_KEYS if key in config}
        init_kwargs.update(overrides)
        return init_kwargs

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def predict_frames(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict semantic frame for a target word in context

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with ``"frame"`` (decoded answer span, or ``"UNKNOWN"``
            when the span is empty or no frame model is loaded) and
            ``"confidence"`` (mean of the top start and end probabilities).
        """
        if not self.has_frame_predictor:
            return {"frame": "UNKNOWN", "confidence": 0.0}

        inputs = self.frame_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Keep the inputs on the same device as the model that consumes them.
        inputs = inputs.to(self._module_device(self.frame_qa_model))

        with torch.no_grad():
            outputs = self.frame_qa_model(**inputs)
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        start_idx = int(torch.argmax(start_logits))
        end_idx = int(torch.argmax(end_logits))

        confidence = (torch.max(torch.softmax(start_logits, dim=-1)) +
                      torch.max(torch.softmax(end_logits, dim=-1))) / 2.0

        # An end index before the start index yields an empty slice, which
        # decodes to "" and is reported as "UNKNOWN" below.
        frame_tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
        frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)

        return {
            "frame": frame if frame else "UNKNOWN",
            "confidence": confidence.item(),
        }

    def predict_source(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict source domain for a metaphor

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with ``"source"`` (looked up in ``source_id2label``,
            ``"UNKNOWN"`` when the id is absent or no source model is loaded)
            and ``"confidence"`` (probability of the predicted id).
        """
        if not self.has_source_predictor:
            return {"source": "UNKNOWN", "confidence": 0.0}

        inputs = self.source_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Keep the inputs on the same device as the model that consumes them.
        inputs = inputs.to(self._module_device(self.source_qa_model))

        with torch.no_grad():
            outputs = self.source_qa_model(**inputs)
        # NOTE(review): when the output has no `logits` attribute this falls
        # back to `start_logits`, so the argmax below is a token position that
        # is then used as a label id - confirm this is the intended mapping.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs.start_logits

        probs = torch.softmax(logits, dim=-1)
        predicted_id = torch.argmax(probs, dim=-1)
        confidence = probs.gather(-1, predicted_id.unsqueeze(-1)).squeeze(-1)

        source = self.source_id2label.get(predicted_id.item(), "UNKNOWN")

        return {
            "source": source,
            "confidence": confidence.item(),
        }

    def predict_metaphor(
        self,
        sentence: str,
        target_word: str,
        target_positions: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Predict if target word is metaphorical in context

        Dropout layers are applied as part of the computation, so put the
        model in evaluation mode (``model.eval()``) for deterministic output.

        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            target_positions: Token positions of target word (optional). When
                omitted, the first occurrence of ``target_word`` in the
                tokenized sentence is used.

        Returns:
            Dictionary with ``is_metaphor``, ``metaphor_confidence``,
            ``frame``, ``frame_confidence``, ``source`` and
            ``source_confidence``. The source is only predicted when the word
            is classified as a metaphor; otherwise it is ``"N/A"`` / ``0.0``.
        """
        inputs = self.tokenizer(
            sentence,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Keep the inputs on the same device as the encoder.
        inputs = inputs.to(self._module_device(self.encoder))

        if target_positions is None:
            target_positions = self._find_target_positions(sentence, target_word)

        # 1.0 at the target token positions, 0.0 elsewhere. Positions beyond
        # the (possibly truncated) sequence are ignored.
        target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
        for pos in target_positions or []:
            if pos < target_mask.size(1):
                target_mask[0, pos] = 1.0

        with torch.no_grad():
            outputs = self.encoder(**inputs)
            sequence_output = outputs[0]
            pooled_output = outputs[1]

            # Mean of the encoder states over the target positions; the
            # epsilon avoids division by zero when no position was found.
            target_output = sequence_output * target_mask.unsqueeze(2)
            target_output = target_output.sum(dim=1) / (target_mask.sum(-1, keepdim=True) + 1e-10)
            target_output = self.dropout(target_output)
            pooled_output = self.dropout(pooled_output)

            SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
            # NOTE(review): both halves of the MIP input are the same
            # contextual target vector - confirm this matches how the
            # checkpoint was trained.
            MIP_hidden = self.MIP_linear(torch.cat([target_output, target_output], dim=1))

            logits = self.classifier(torch.cat([SPV_hidden, MIP_hidden], dim=1))
            logits = self.logsoftmax(logits)
            probs = torch.exp(logits)

        is_metaphor = torch.argmax(probs, dim=1).item() == 1
        metaphor_confidence = probs[0, 1].item()

        frame_result = self.predict_frames(sentence, target_word)
        if is_metaphor:
            source_result = self.predict_source(sentence, target_word)
        else:
            source_result = {"source": "N/A", "confidence": 0.0}

        return {
            "is_metaphor": is_metaphor,
            "metaphor_confidence": metaphor_confidence,
            "frame": frame_result["frame"],
            "frame_confidence": frame_result["confidence"],
            "source": source_result["source"],
            "source_confidence": source_result["confidence"],
        }

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """
        Load model from pretrained checkpoint

        Args:
            model_path: Directory containing ``config.json`` and, optionally,
                ``pytorch_model.bin``.
            **kwargs: Constructor arguments; these override values stored in
                ``config.json``.

        Returns:
            A model built from the saved hyperparameters, with the saved
            weights loaded non-strictly when the weights file exists.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
        """
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Rebuild with the hyperparameters the checkpoint was saved with;
        # otherwise layers of non-default size cannot receive the weights.
        model = cls(**cls._init_kwargs_from_config(config, kwargs))

        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(weights_path):
            # SECURITY: torch.load may unpickle arbitrary objects; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)

        return model

    def save_pretrained(self, save_directory):
        """
        Save model to directory

        Writes ``pytorch_model.bin`` (the full state dict), ``config.json``
        (hyperparameters and model names) and the encoder tokenizer files
        into ``save_directory``, creating it if needed.
        """
        os.makedirs(save_directory, exist_ok=True)

        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        config = {
            "_name_or_path": "ConceptFrameMet",
            "architectures": ["ConceptFrameMetForMetaphorDetection"],
            "model_type": "conceptframemet",
            "encoder_model_name": self.encoder_model_name,
            "frame_qa_model_name": self.frame_qa_model_name,
            "source_qa_model_name": self.source_qa_model_name,
            "num_labels": self.num_labels,
            "classifier_hidden": self.classifier_hidden,
            "drop_ratio": self.drop_ratio,
            "source_blend_mode": self.source_blend_mode,
            "source_use_mode": self.source_use_mode,
            "source_alpha": self.source_alpha,
            "metaphor_threshold": self.metaphor_threshold,
        }

        with open(os.path.join(save_directory, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        self.tokenizer.save_pretrained(save_directory)