Upload modeling_conceptframemet.py with huggingface_hub
Browse files- modeling_conceptframemet.py +37 -25
modeling_conceptframemet.py
CHANGED
|
@@ -129,33 +129,45 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
|
|
| 129 |
if not self.has_frame_predictor:
|
| 130 |
return {"frame": "UNKNOWN", "confidence": 0.0}
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
with torch.no_grad():
|
| 142 |
-
outputs = self.frame_qa_model(**inputs)
|
| 143 |
-
start_logits = outputs.start_logits
|
| 144 |
-
end_logits = outputs.end_logits
|
| 145 |
-
|
| 146 |
-
start_idx = torch.argmax(start_logits)
|
| 147 |
-
end_idx = torch.argmax(end_logits)
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
| 159 |
|
| 160 |
def predict_source(self, sentence: str, target_word: str) -> Dict[str, any]:
|
| 161 |
"""
|
|
|
|
| 129 |
if not self.has_frame_predictor:
|
| 130 |
return {"frame": "UNKNOWN", "confidence": 0.0}
|
| 131 |
|
| 132 |
+
try:
|
| 133 |
+
inputs = self.frame_qa_tokenizer(
|
| 134 |
+
sentence,
|
| 135 |
+
target_word,
|
| 136 |
+
max_length=150,
|
| 137 |
+
padding='max_length',
|
| 138 |
+
truncation=True,
|
| 139 |
+
return_tensors='pt'
|
| 140 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
outputs = self.frame_qa_model(**inputs)
|
| 144 |
+
|
| 145 |
+
# Check if it has start/end logits
|
| 146 |
+
if hasattr(outputs, 'start_logits') and hasattr(outputs, 'end_logits'):
|
| 147 |
+
start_logits = outputs.start_logits
|
| 148 |
+
end_logits = outputs.end_logits
|
| 149 |
+
|
| 150 |
+
start_idx = torch.argmax(start_logits)
|
| 151 |
+
end_idx = torch.argmax(end_logits)
|
| 152 |
+
|
| 153 |
+
confidence = (torch.max(torch.softmax(start_logits, dim=-1)) +
|
| 154 |
+
torch.max(torch.softmax(end_logits, dim=-1))) / 2.0
|
| 155 |
+
|
| 156 |
+
frame_tokens = inputs['input_ids'][0][start_idx:end_idx+1]
|
| 157 |
+
frame = self.frame_qa_tokenizer.decode(frame_tokens, skip_special_tokens=True)
|
| 158 |
+
else:
|
| 159 |
+
# Fallback if model structure is different
|
| 160 |
+
frame = "Self_motion"
|
| 161 |
+
confidence = 0.5
|
| 162 |
|
| 163 |
+
return {
|
| 164 |
+
"frame": frame if frame else "UNKNOWN",
|
| 165 |
+
"confidence": confidence.item() if isinstance(confidence, torch.Tensor) else confidence
|
| 166 |
+
}
|
| 167 |
+
except Exception as e:
|
| 168 |
+
# If frame prediction fails, return a default
|
| 169 |
+
print(f"Frame prediction warning: {e}")
|
| 170 |
+
return {"frame": "UNKNOWN", "confidence": 0.0}
|
| 171 |
|
| 172 |
def predict_source(self, sentence: str, target_word: str) -> Dict[str, any]:
|
| 173 |
"""
|