# ConceptFrameMet / modeling_conceptframemet.py
# Uploaded by nixie1981 (folder upload using huggingface_hub)
# Commit 1b12abd (verified) - 11.4 kB
"""
ConceptFrameMet: Metaphor Detection with Frame and Source Domain Prediction
This model detects metaphors and predicts their semantic frames and source domains.
Based on AdaptiveSourceQAMelBert architecture.
"""
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    RobertaModel,
    RobertaTokenizer,
)
class ConceptFrameMetForMetaphorDetection(nn.Module):
    """
    Metaphor detection model with semantic frame and source domain prediction.

    This model:
    - Detects metaphors in text
    - Predicts semantic frames for target words
    - Predicts source domains for metaphors

    The metaphor classifier is a RoBERTa encoder followed by two linear
    projections (``SPV_linear`` and ``MIP_linear``) whose outputs are
    concatenated and fed to ``classifier``. Frame and source prediction are
    delegated to separately loaded question-answering checkpoints.
    """

    # Constructor arguments that are written to config.json by
    # ``save_pretrained`` and fed back to ``__init__`` by ``from_pretrained``.
    _INIT_CONFIG_KEYS = (
        "encoder_model_name",
        "frame_qa_model_name",
        "source_qa_model_name",
        "classifier_hidden",
        "drop_ratio",
        "num_labels",
        "source_blend_mode",
        "source_use_mode",
        "source_alpha",
        "metaphor_threshold",
    )

    def __init__(
        self,
        encoder_model_name="roberta-base",
        frame_qa_model_name="nixie1981/sem_frames",
        source_qa_model_name=None,
        classifier_hidden=768,
        drop_ratio=0.2,
        num_labels=2,
        source_blend_mode='replacement',
        source_use_mode='metaphor_only',
        source_alpha=0.3,
        metaphor_threshold=0.5,
    ):
        """
        Build the encoder, the optional QA predictors and the classifier head.

        Args:
            encoder_model_name: Checkpoint passed to ``RobertaModel`` /
                ``RobertaTokenizer``.
            frame_qa_model_name: QA checkpoint used by ``predict_frames``.
            source_qa_model_name: QA checkpoint used by ``predict_source``;
                when falsy no source predictor is loaded.
            classifier_hidden: Output width of the SPV and MIP projections.
            drop_ratio: Dropout probability applied before the projections.
            num_labels: Number of classifier outputs.
            source_blend_mode: Stored and saved; not read by this class.
            source_use_mode: Stored and saved; not read by this class.
            source_alpha: Stored and saved; not read by this class.
            metaphor_threshold: Probability of label 1 above which a word is
                reported as metaphorical (binary case only).
        """
        super().__init__()
        # Remember checkpoint names so save_pretrained can persist them.
        self.encoder_model_name = encoder_model_name
        self.frame_qa_model_name = frame_qa_model_name
        self.source_qa_model_name = source_qa_model_name

        self.num_labels = num_labels
        self.classifier_hidden = classifier_hidden
        self.drop_ratio = drop_ratio

        # Configuration
        self.source_blend_mode = source_blend_mode
        self.source_use_mode = source_use_mode
        self.source_alpha = source_alpha
        self.metaphor_threshold = metaphor_threshold

        # Load encoder (RoBERTa)
        self.encoder = RobertaModel.from_pretrained(encoder_model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(encoder_model_name)
        self.config = self.encoder.config

        # Load frame QA model (best effort: prediction degrades to "UNKNOWN")
        self.frame_qa_model, self.frame_qa_tokenizer = self._load_qa_pair(
            frame_qa_model_name, "Frame"
        )
        self.has_frame_predictor = self.frame_qa_model is not None

        # Load source QA model (if a checkpoint name was given)
        if source_qa_model_name:
            self.source_qa_model, self.source_qa_tokenizer = self._load_qa_pair(
                source_qa_model_name, "Source"
            )
            self.has_source_predictor = self.source_qa_model is not None
        else:
            self.has_source_predictor = False

        # Dropout
        self.dropout = nn.Dropout(drop_ratio)

        # Classification layers
        self.SPV_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.MIP_linear = nn.Linear(self.config.hidden_size * 2, classifier_hidden)
        self.classifier = nn.Linear(classifier_hidden * 2, num_labels)
        self._init_weights(self.SPV_linear)
        self._init_weights(self.MIP_linear)
        self._init_weights(self.classifier)
        self.logsoftmax = nn.LogSoftmax(dim=1)

        # Label maps (id -> name); empty until filled by the caller or
        # restored from config.json in from_pretrained.
        self.source_id2label = {}
        self.frame_id2label = {}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _load_qa_pair(model_name, description):
        """Load a QA model and its tokenizer; return ``(None, None)`` on failure."""
        try:
            model = AutoModelForQuestionAnswering.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as exc:  # best effort, but never swallow Ctrl-C / SystemExit
            print(f"Warning: {description} QA model not available ({exc})")
            return None, None
        return model, tokenizer

    @staticmethod
    def _module_device(module) -> torch.device:
        """Return the device holding the first parameter of ``module``."""
        return next(module.parameters()).device

    @staticmethod
    def _enter_eval(module):
        """Switch ``module`` to eval mode and return the previous per-module flags."""
        snapshot = [(sub, sub.training) for sub in module.modules()]
        module.eval()
        return snapshot

    @staticmethod
    def _restore_modes(snapshot):
        """Restore the training flags captured by ``_enter_eval``."""
        for sub, was_training in snapshot:
            sub.training = was_training

    @staticmethod
    def _locate_target(sentence_tokens, candidate_token_seqs) -> List[int]:
        """
        Find the earliest occurrence of any candidate token sequence.

        Returns the matching positions shifted by one, because the encoded
        sentence starts with a special token that ``tokenize`` does not emit.
        An empty list means no candidate occurs in ``sentence_tokens``.
        """
        best_start = None
        best_width = 0
        for needle in candidate_token_seqs:
            width = len(needle)
            if width == 0:
                continue
            for start in range(len(sentence_tokens) - width + 1):
                if sentence_tokens[start:start + width] == needle:
                    if best_start is None or start < best_start:
                        best_start, best_width = start, width
                    break
        if best_start is None:
            return []
        return list(range(best_start + 1, best_start + 1 + best_width))

    def _resolve_target_positions(self, sentence: str, target_word: str) -> List[int]:
        """Locate ``target_word`` inside the tokenized ``sentence``."""
        sentence_tokens = self.tokenizer.tokenize(sentence)
        # RoBERTa's BPE encodes a preceding space as part of the token, so a
        # word in the middle of a sentence tokenizes differently from the same
        # word on its own. Try both spellings.
        candidates = [
            self.tokenizer.tokenize(target_word),
            self.tokenizer.tokenize(" " + target_word),
        ]
        return self._locate_target(sentence_tokens, candidates)

    @classmethod
    def _merge_init_kwargs(cls, config, overrides):
        """Combine saved constructor arguments with caller overrides (overrides win)."""
        merged = {key: config[key] for key in cls._INIT_CONFIG_KEYS if key in config}
        merged.update(overrides)
        return merged

    @staticmethod
    def _int_keyed(mapping):
        """Rebuild an id->label dict whose keys were stringified by JSON."""
        return {int(key): value for key, value in (mapping or {}).items()}

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    # ------------------------------------------------------------------
    # Prediction API
    # ------------------------------------------------------------------
    def predict_frames(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict semantic frame for a target word in context

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with frame prediction and confidence. The frame is the
            decoded answer span, or "UNKNOWN" when no predictor is loaded or
            the span is empty.
        """
        if not self.has_frame_predictor:
            return {"frame": "UNKNOWN", "confidence": 0.0}

        device = self._module_device(self.frame_qa_model)
        encoded = self.frame_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        snapshot = self._enter_eval(self.frame_qa_model)
        try:
            with torch.no_grad():
                outputs = self.frame_qa_model(**encoded)
        finally:
            self._restore_modes(snapshot)

        start_logits = outputs.start_logits
        end_logits = outputs.end_logits
        start_idx = int(torch.argmax(start_logits).item())
        end_idx = int(torch.argmax(end_logits).item())

        # Mean of the peak start and peak end probabilities.
        confidence = (torch.max(torch.softmax(start_logits, dim=-1)) +
                      torch.max(torch.softmax(end_logits, dim=-1))) / 2.0

        # An end before the start yields an empty span, reported as "UNKNOWN".
        frame_tokens = encoded['input_ids'][0][start_idx:end_idx + 1].tolist()
        frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)
        return {
            "frame": frame if frame else "UNKNOWN",
            "confidence": confidence.item()
        }

    def predict_source(self, sentence: str, target_word: str) -> Dict[str, Any]:
        """
        Predict source domain for a metaphor

        Args:
            sentence: Input sentence
            target_word: Target word to analyze

        Returns:
            Dictionary with source prediction and confidence. The source is
            looked up in ``source_id2label`` and falls back to "UNKNOWN".
        """
        if not self.has_source_predictor:
            return {"source": "UNKNOWN", "confidence": 0.0}

        device = self._module_device(self.source_qa_model)
        encoded = self.source_qa_tokenizer(
            sentence,
            target_word,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        snapshot = self._enter_eval(self.source_qa_model)
        try:
            with torch.no_grad():
                outputs = self.source_qa_model(**encoded)
        finally:
            self._restore_modes(snapshot)

        # NOTE(review): the model is loaded with AutoModelForQuestionAnswering,
        # so this most likely falls through to start_logits and the argmax is
        # a token position rather than a class id - confirm against training.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs.start_logits
        probs = torch.softmax(logits, dim=-1)
        predicted_id = torch.argmax(probs, dim=-1)
        confidence = probs.gather(-1, predicted_id.unsqueeze(-1)).squeeze(-1)
        source = self.source_id2label.get(predicted_id.item(), "UNKNOWN")
        return {
            "source": source,
            "confidence": confidence.item()
        }

    def predict_metaphor(
        self,
        sentence: str,
        target_word: str,
        target_positions: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Predict if target word is metaphorical in context

        Args:
            sentence: Input sentence
            target_word: Target word to analyze
            target_positions: Token positions of target word (optional). When
                omitted they are searched for in the tokenized sentence.

        Returns:
            Dictionary with metaphor prediction, frame, and source
        """
        device = self._module_device(self.encoder)

        # Tokenize input
        inputs = self.tokenizer(
            sentence,
            max_length=150,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # Create target mask
        if target_positions is None:
            target_positions = self._resolve_target_positions(sentence, target_word)

        target_mask = torch.zeros_like(inputs['input_ids'], dtype=torch.float)
        seq_len = target_mask.size(1)
        for pos in target_positions:
            # Skip positions lost to truncation; negative values would
            # otherwise wrap around and mark a token at the end.
            if 0 <= pos < seq_len:
                target_mask[0, pos] = 1.0
        # If nothing was marked the mask stays zero and the target vector
        # below is all zeros.

        # Forward pass for metaphor detection. Eval mode is required so that
        # dropout is disabled; the previous mode is restored afterwards.
        snapshot = self._enter_eval(self)
        try:
            with torch.no_grad():
                outputs = self.encoder(**inputs)
                sequence_output = outputs[0]
                pooled_output = outputs[1]

                # Masked mean over the target tokens
                target_output = sequence_output * target_mask.unsqueeze(2)
                target_output = target_output.sum(dim=1) / (
                    target_mask.sum(-1, keepdim=True) + 1e-10
                )
                target_output = self.dropout(target_output)
                pooled_output = self.dropout(pooled_output)

                # SPV and MIP
                SPV_hidden = self.SPV_linear(
                    torch.cat([pooled_output, target_output], dim=1)
                )
                # NOTE(review): both halves of the MIP input are the same
                # contextual vector; if training used a separately encoded
                # target word, this should mirror it - confirm.
                MIP_hidden = self.MIP_linear(
                    torch.cat([target_output, target_output], dim=1)
                )

                # Classification
                logits = self.classifier(torch.cat([SPV_hidden, MIP_hidden], dim=1))
                logits = self.logsoftmax(logits)
                probs = torch.exp(logits)
        finally:
            self._restore_modes(snapshot)

        metaphor_confidence = probs[0, 1].item()
        if self.num_labels == 2:
            # With the default threshold of 0.5 this equals argmax == 1.
            is_metaphor = metaphor_confidence > self.metaphor_threshold
        else:
            is_metaphor = torch.argmax(probs, dim=1).item() == 1

        # Predict frame and source
        frame_result = self.predict_frames(sentence, target_word)
        if is_metaphor:
            source_result = self.predict_source(sentence, target_word)
        else:
            source_result = {"source": "N/A", "confidence": 0.0}

        return {
            "is_metaphor": is_metaphor,
            "metaphor_confidence": metaphor_confidence,
            "frame": frame_result["frame"],
            "frame_confidence": frame_result["confidence"],
            "source": source_result["source"],
            "source_confidence": source_result["confidence"]
        }

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """
        Load model from pretrained checkpoint

        Args:
            model_path: Directory containing ``config.json`` and, optionally,
                ``pytorch_model.bin``.
            **kwargs: Constructor arguments; they override values stored in
                ``config.json``.

        Returns:
            The constructed model with weights loaded when the file exists.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
        """
        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Initialize model with the saved hyperparameters so that the layer
        # shapes match the stored weights.
        model = cls(**cls._merge_init_kwargs(config, kwargs))
        model.source_id2label = cls._int_keyed(config.get("source_id2label"))
        model.frame_id2label = cls._int_keyed(config.get("frame_id2label"))

        # Load weights
        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(weights_path):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        return model

    def save_pretrained(self, save_directory):
        """
        Save model to directory

        Writes ``pytorch_model.bin``, ``config.json`` and the encoder
        tokenizer files into ``save_directory`` (created if needed).
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save weights
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

        # Save config
        config = {
            "_name_or_path": "ConceptFrameMet",
            "architectures": ["ConceptFrameMetForMetaphorDetection"],
            "model_type": "conceptframemet",
            "encoder_model_name": self.encoder_model_name,
            "frame_qa_model_name": self.frame_qa_model_name,
            "source_qa_model_name": self.source_qa_model_name,
            "num_labels": self.num_labels,
            "classifier_hidden": self.classifier_hidden,
            "drop_ratio": self.drop_ratio,
            "source_blend_mode": self.source_blend_mode,
            "source_use_mode": self.source_use_mode,
            "source_alpha": self.source_alpha,
            "metaphor_threshold": self.metaphor_threshold,
            # JSON turns the integer ids into strings; from_pretrained
            # converts them back.
            "source_id2label": self.source_id2label,
            "frame_id2label": self.frame_id2label,
        }
        with open(os.path.join(save_directory, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        # Save tokenizer
        self.tokenizer.save_pretrained(save_directory)