Update handler.py
Browse files- handler.py +14 -5
handler.py
CHANGED
|
@@ -173,6 +173,7 @@ class EndpointHandler:
|
|
| 173 |
# Validate and format inputs for the pipeline
|
| 174 |
pipeline_inputs = []
|
| 175 |
valid_indices = []
|
|
|
|
| 176 |
|
| 177 |
for i, item in enumerate(inputs):
|
| 178 |
utterance = item.get("utterance", "").strip()
|
|
@@ -194,6 +195,7 @@ class EndpointHandler:
|
|
| 194 |
'context': context # Now includes conversation history
|
| 195 |
})
|
| 196 |
valid_indices.append(i)
|
|
|
|
| 197 |
|
| 198 |
# Run prediction
|
| 199 |
results = []
|
|
@@ -258,8 +260,12 @@ class EndpointHandler:
|
|
| 258 |
raw_answers = [p.get("answer", "") for p in current_preds]
|
| 259 |
raw_answers = [a for a in raw_answers if is_good_span(a)]
|
| 260 |
|
| 261 |
-
#
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
results.append({
|
| 265 |
"utterance": utterance,
|
|
@@ -280,12 +286,14 @@ class EndpointHandler:
|
|
| 280 |
"triggers": []
|
| 281 |
} for item in inputs]
|
| 282 |
|
| 283 |
-
def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
|
| 284 |
"""
|
| 285 |
Clean and filter extracted trigger spans.
|
| 286 |
-
|
| 287 |
"""
|
| 288 |
target_text = target_text or ""
|
|
|
|
|
|
|
| 289 |
target_lower = target_text.lower()
|
| 290 |
|
| 291 |
def _norm(s: str) -> str:
|
|
@@ -316,7 +324,8 @@ class EndpointHandler:
|
|
| 316 |
s_norm = _norm(s)
|
| 317 |
if not s_norm:
|
| 318 |
continue
|
| 319 |
-
|
|
|
|
| 320 |
continue
|
| 321 |
tokens = s_norm.split()
|
| 322 |
if len(tokens) > 8 or len(s_norm) > 80:
|
|
|
|
| 173 |
# Validate and format inputs for the pipeline
|
| 174 |
pipeline_inputs = []
|
| 175 |
valid_indices = []
|
| 176 |
+
contexts = [] # Store contexts for later use in cleaning
|
| 177 |
|
| 178 |
for i, item in enumerate(inputs):
|
| 179 |
utterance = item.get("utterance", "").strip()
|
|
|
|
| 195 |
'context': context # Now includes conversation history
|
| 196 |
})
|
| 197 |
valid_indices.append(i)
|
| 198 |
+
contexts.append(context) # Store for later use
|
| 199 |
|
| 200 |
# Run prediction
|
| 201 |
results = []
|
|
|
|
| 260 |
raw_answers = [p.get("answer", "") for p in current_preds]
|
| 261 |
raw_answers = [a for a in raw_answers if is_good_span(a)]
|
| 262 |
|
| 263 |
+
# Extract context text (part before [TARGET] marker)
|
| 264 |
+
full_context = contexts[pred_idx]
|
| 265 |
+
context_without_target = full_context.split("[TARGET]")[0].strip() if "[TARGET]" in full_context else ""
|
| 266 |
+
|
| 267 |
+
# Clean spans against BOTH target utterance AND context
|
| 268 |
+
triggers = self._clean_spans(raw_answers, utterance, context_without_target)
|
| 269 |
|
| 270 |
results.append({
|
| 271 |
"utterance": utterance,
|
|
|
|
| 286 |
"triggers": []
|
| 287 |
} for item in inputs]
|
| 288 |
|
| 289 |
+
def _clean_spans(self, spans: List[str], target_text: str, context_text: str = "") -> List[str]:
|
| 290 |
"""
|
| 291 |
Clean and filter extracted trigger spans.
|
| 292 |
+
Spans can come from either target_text or context_text.
|
| 293 |
"""
|
| 294 |
target_text = target_text or ""
|
| 295 |
+
context_text = context_text or ""
|
| 296 |
+
full_text = (context_text + " " + target_text).lower()
|
| 297 |
target_lower = target_text.lower()
|
| 298 |
|
| 299 |
def _norm(s: str) -> str:
|
|
|
|
| 324 |
s_norm = _norm(s)
|
| 325 |
if not s_norm:
|
| 326 |
continue
|
| 327 |
+
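# Skip spans not found in the lowercased "context + ' ' + target" text. NOTE(review): full_text always contains the joining space, so the `full_text and` guard is always true, and a span may match across the context/target boundary.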
|
| 328 |
+
if full_text and s_norm not in full_text:
|
| 329 |
continue
|
| 330 |
tokens = s_norm.split()
|
| 331 |
if len(tokens) > 8 or len(s_norm) > 80:
|