"""
ConceptFrameMet: Metaphor Detection with Frame and Source Domain Prediction

This model detects metaphors and predicts their semantic frames and source domains.
Based on AdaptiveSourceQAMelBert architecture.
"""

import json
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer, AutoModelForQuestionAnswering, AutoTokenizer


class ConceptFrameMetForMetaphorDetection(nn.Module):
    """
    Metaphor detection model with semantic frame and source domain prediction capabilities.

    This model:
    - Detects metaphors in text
    - Predicts semantic frames for target words
    - Predicts source domains for metaphors

    The module owns a RoBERTa encoder plus two optional question-answering
    models (frame and source). Both QA models are loaded on a best-effort
    basis; ``has_frame_predictor`` / ``has_source_predictor`` record whether
    loading succeeded.

    NOTE(review): ``source_blend_mode``, ``source_use_mode``, ``source_alpha``
    and ``metaphor_threshold`` are stored and serialized, but no method in
    this class consults them -- confirm whether that is intended.
    """

    # Keys of config.json that map one-to-one onto ``__init__`` parameters.
    _CONFIG_INIT_KEYS = (
        "encoder_model_name",
        "frame_qa_model_name",
        "source_qa_model_name",
        "classifier_hidden",
        "drop_ratio",
        "num_labels",
        "source_blend_mode",
        "source_use_mode",
        "source_alpha",
        "metaphor_threshold",
    )

    def __init__(
        self,
        encoder_model_name="roberta-base",
        frame_qa_model_name="nixie1981/sem_frames",
        source_qa_model_name=None,
        classifier_hidden=768,
        drop_ratio=0.2,
        num_labels=2,
        source_blend_mode='replacement',
        source_use_mode='metaphor_only',
        source_alpha=0.3,
        metaphor_threshold=0.5,
    ):
        """
        Build the encoder, the optional QA predictors and the classifier head.

        Args:
            encoder_model_name: Name or path passed to ``RobertaModel.from_pretrained``.
            frame_qa_model_name: Name or path of the frame QA model.
            source_qa_model_name: Name or path of the source QA model; when
                falsy no source predictor is loaded.
            classifier_hidden: Width of the SPV and MIP hidden layers.
            drop_ratio: Dropout probability.
            num_labels: Number of output classes of the classifier.
            source_blend_mode: Stored configuration value.
            source_use_mode: Stored configuration value.
            source_alpha: Stored configuration value.
            metaphor_threshold: Stored configuration value.
        """
        super().__init__()

        self.num_labels = num_labels
        self.classifier_hidden = classifier_hidden
        self.drop_ratio = drop_ratio

        # Remember which checkpoints were requested so save_pretrained can
        # record them and from_pretrained can rebuild the same sub-models.
        self.encoder_model_name = encoder_model_name
        self.frame_qa_model_name = frame_qa_model_name
        self.source_qa_model_name = source_qa_model_name

        # Configuration
        self.source_blend_mode = source_blend_mode
        self.source_use_mode = source_use_mode
        self.source_alpha = source_alpha
        self.metaphor_threshold = metaphor_threshold

        # Load encoder (RoBERTa)
        self.encoder = RobertaModel.from_pretrained(encoder_model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(encoder_model_name)
        self.config = self.encoder.config

        # Load frame QA model (best effort). ``except Exception`` rather than
        # a bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
        try:
            self.frame_qa_model = AutoModelForQuestionAnswering.from_pretrained(frame_qa_model_name)
            self.frame_qa_tokenizer = AutoTokenizer.from_pretrained(frame_qa_model_name)
            self.has_frame_predictor = True
        except Exception:
            print("Warning: Frame QA model not available")
            self.has_frame_predictor = False

        # Load source QA model (if available)
        if source_qa_model_name:
            try:
                self.source_qa_model = AutoModelForQuestionAnswering.from_pretrained(source_qa_model_name)
                self.source_qa_tokenizer = AutoTokenizer.from_pretrained(source_qa_model_name)
                self.has_source_predictor = True
            except Exception:
                print("Warning: Source QA model not available")
                self.has_source_predictor = False
        else:
            self.has_source_predictor = False

        # Dropout
        self.dropout = nn.Dropout(drop_ratio)

        # Classification layers
        self.SPV_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.MIP_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.classifier = nn.Linear(classifier_hidden * 2, num_labels)

        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)

        self.logsoftmax = nn.LogSoftmax(dim=1)

        # Label maps start empty; predict_source falls back to "UNKNOWN"
        # for any id that is not present in source_id2label.
        self.source_id2label = {}
        self.frame_id2label = {}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _device_of(module):
        """Return the device of ``module``'s first parameter (CPU if it has none)."""
        first_param = next(module.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    @contextmanager
    def _inference_context(self):
        """
        Run the enclosed code in eval mode with autograd disabled.

        The training flag of every submodule is recorded beforehand and
        restored afterwards, so calling a ``predict_*`` method never changes
        the train/eval state the caller had set up.
        """
        previous_flags = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                yield
        finally:
            for module, was_training in previous_flags:
                module.training = was_training

    def _find_target_positions(self, sentence: str, target_word: str) -> List[int]:
        """
        Locate the first occurrence of ``target_word`` in the tokenized sentence.

        The tokenizer may encode a word differently depending on whether it
        is preceded by a space, so both spellings are tried; otherwise only a
        sentence-initial target could ever match.

        Returns:
            Token positions shifted by one to account for the leading special
            token, or an empty list when the target is not found.
        """
        sentence_tokens = self.tokenizer.tokenize(sentence)

        candidates = []
        for text in (target_word, " " + target_word):
            tokens = self.tokenizer.tokenize(text)
            if tokens and tokens not in candidates:
                candidates.append(tokens)

        for start in range(len(sentence_tokens)):
            for tokens in candidates:
                end = start + len(tokens)
                if sentence_tokens[start:end] == tokens:
                    return list(range(start + 1, end + 1))  # +1 for CLS token
        return []

    def predict_frames(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict semantic frame for a target word in context

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with frame prediction and confidence. The frame is the
            decoded token span between the arg-max start and end positions
            ("UNKNOWN" when that span decodes to an empty string or when no
            frame predictor is loaded); the confidence is the mean of the
            highest start and end softmax probabilities.
        """
        if not self.has_frame_predictor:
            return {"frame": "UNKNOWN", "confidence": 0.0}

        inputs = self.frame_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenizer output lives on CPU; follow the model if it was moved.
        inputs = inputs.to(self._device_of(self.frame_qa_model))

        with self._inference_context():
            outputs = self.frame_qa_model(**inputs)

        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        start_idx = int(torch.argmax(start_logits))
        end_idx = int(torch.argmax(end_logits))

        confidence = (
            torch.max(torch.softmax(start_logits, dim=-1))
            + torch.max(torch.softmax(end_logits, dim=-1))
        ) / 2.0

        # An end index before the start index yields an empty slice, which
        # decodes to "" and is reported as "UNKNOWN" below.
        frame_tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
        frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)

        return {
            "frame": frame if frame else "UNKNOWN",
            "confidence": confidence.item()
        }

    def predict_source(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict source domain for a metaphor

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with source prediction and confidence. The source is
            looked up in ``source_id2label`` by the arg-max index and is
            "UNKNOWN" when the index is absent or no source predictor is
            loaded.
        """
        if not self.has_source_predictor:
            return {"source": "UNKNOWN", "confidence": 0.0}

        inputs = self.source_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenizer output lives on CPU; follow the model if it was moved.
        inputs = inputs.to(self._device_of(self.source_qa_model))

        with self._inference_context():
            outputs = self.source_qa_model(**inputs)

        # NOTE(review): when the output has no ``logits`` attribute the start
        # logits are used, so the arg-max is taken over that tensor's last
        # axis -- confirm that source_id2label is keyed to match.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs.start_logits
        probs = torch.softmax(logits, dim=-1)
        predicted_id = torch.argmax(probs, dim=-1)
        confidence = probs.gather(-1, predicted_id.unsqueeze(-1)).squeeze(-1)

        source = self.source_id2label.get(predicted_id.item(), "UNKNOWN")

        return {
            "source": source,
            "confidence": confidence.item()
        }

    def predict_metaphor(
        self,
        sentence: str,
        target_word: str,
        target_positions: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Predict if target word is metaphorical in context

        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            target_positions: Token positions of target word (optional). When
                omitted, the first occurrence of ``target_word`` in the
                tokenized sentence is used.

        Returns:
            Dictionary with metaphor prediction, frame, and source. The source
            is only predicted when the word is classified as a metaphor;
            otherwise it is reported as "N/A" with confidence 0.0.
        """
        # Tokenize input
        inputs = self.tokenizer(
            sentence,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Tokenizer output lives on CPU; follow the encoder if it was moved.
        inputs = inputs.to(self._device_of(self.encoder))

        # Create target mask
        if target_positions is None:
            target_positions = self._find_target_positions(sentence, target_word)

        target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
        if target_positions:
            seq_len = target_mask.size(1)
            for pos in target_positions:
                # Positions outside the (possibly truncated) sequence are ignored.
                if 0 <= pos < seq_len:
                    target_mask[0, pos] = 1.0

        # Forward pass for metaphor detection. The inference context keeps
        # dropout inactive so repeated calls give the same prediction.
        with self._inference_context():
            outputs = self.encoder(**inputs)
            sequence_output = outputs[0]
            pooled_output = outputs[1]

            # Masked mean over the target tokens; the epsilon avoids a
            # division by zero when the mask is empty.
            target_output = sequence_output * target_mask.unsqueeze(2)
            target_output = target_output.sum(dim=1) / (target_mask.sum(-1, keepdim=True) + 1e-10)

            # Identity in eval mode; kept to mirror the training-time graph.
            target_output = self.dropout(target_output)
            pooled_output = self.dropout(pooled_output)

            # SPV and MIP
            SPV_hidden = self.SPV_linear(torch.cat([pooled_output, target_output], dim=1))
            # NOTE(review): both halves of the MIP input are the same
            # contextual target vector -- confirm this matches how the
            # weights were trained.
            MIP_hidden = self.MIP_linear(torch.cat([target_output, target_output], dim=1))

            # Classification
            logits = self.classifier(torch.cat([SPV_hidden, MIP_hidden], dim=1))
            logits = self.logsoftmax(logits)
            probs = torch.exp(logits)

        is_metaphor = torch.argmax(probs, dim=1).item() == 1
        metaphor_confidence = probs[0, 1].item()

        # Predict frame and source
        frame_result = self.predict_frames(sentence, target_word)
        if is_metaphor:
            source_result = self.predict_source(sentence, target_word)
        else:
            source_result = {"source": "N/A", "confidence": 0.0}

        return {
            "is_metaphor": is_metaphor,
            "metaphor_confidence": metaphor_confidence,
            "frame": frame_result["frame"],
            "frame_confidence": frame_result["confidence"],
            "source": source_result["source"],
            "source_confidence": source_result["confidence"]
        }

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """
        Load model from pretrained checkpoint

        Args:
            model_path: Directory containing ``config.json`` and, optionally,
                ``pytorch_model.bin``.
            **kwargs: Constructor arguments; these take precedence over the
                values stored in ``config.json``.

        Returns:
            A model built from the saved configuration with the saved weights
            loaded non-strictly when the weights file exists.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
        """
        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Rebuild the architecture that was saved: without this a checkpoint
        # with a non-default head size cannot be loaded. Explicit keyword
        # arguments still win over the stored values.
        init_kwargs = {key: config[key] for key in cls._CONFIG_INIT_KEYS if key in config}
        init_kwargs.update(kwargs)

        # Initialize model
        model = cls(**init_kwargs)

        # Load weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(weights_path):
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)

        return model

    def save_pretrained(self, save_directory):
        """
        Save model to directory

        Writes ``pytorch_model.bin`` (the full state dict), ``config.json``
        and the encoder tokenizer files into ``save_directory``, creating the
        directory if needed.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save weights
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        # Save config
        config = {
            "_name_or_path": "ConceptFrameMet",
            "architectures": ["ConceptFrameMetForMetaphorDetection"],
            "model_type": "conceptframemet",
            # Sub-model identifiers, so from_pretrained can rebuild them.
            "encoder_model_name": self.encoder_model_name,
            "frame_qa_model_name": self.frame_qa_model_name,
            "source_qa_model_name": self.source_qa_model_name,
            "num_labels": self.num_labels,
            "classifier_hidden": self.classifier_hidden,
            "drop_ratio": self.drop_ratio,
            "source_blend_mode": self.source_blend_mode,
            "source_use_mode": self.source_use_mode,
            "source_alpha": self.source_alpha,
            "metaphor_threshold": self.metaphor_threshold,
        }

        with open(os.path.join(save_directory, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # Save tokenizer
        self.tokenizer.save_pretrained(save_directory)