Khriis committed on
Commit
6e7cb9b
verified
1 Parent(s): aaa853a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -5
handler.py CHANGED
@@ -173,6 +173,7 @@ class EndpointHandler:
173
  # Validate and format inputs for the pipeline
174
  pipeline_inputs = []
175
  valid_indices = []
 
176
 
177
  for i, item in enumerate(inputs):
178
  utterance = item.get("utterance", "").strip()
@@ -194,6 +195,7 @@ class EndpointHandler:
194
  'context': context # Now includes conversation history
195
  })
196
  valid_indices.append(i)
 
197
 
198
  # Run prediction
199
  results = []
@@ -258,8 +260,12 @@ class EndpointHandler:
258
  raw_answers = [p.get("answer", "") for p in current_preds]
259
  raw_answers = [a for a in raw_answers if is_good_span(a)]
260
 
261
- # Clean spans against ORIGINAL utterance (not full context)
262
- triggers = self._clean_spans(raw_answers, utterance)
 
 
 
 
263
 
264
  results.append({
265
  "utterance": utterance,
@@ -280,12 +286,14 @@ class EndpointHandler:
280
  "triggers": []
281
  } for item in inputs]
282
 
283
- def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
284
  """
285
  Clean and filter extracted trigger spans.
286
- (Logic preserved exactly as provided)
287
  """
288
  target_text = target_text or ""
 
 
289
  target_lower = target_text.lower()
290
 
291
  def _norm(s: str) -> str:
@@ -316,7 +324,8 @@ class EndpointHandler:
316
  s_norm = _norm(s)
317
  if not s_norm:
318
  continue
319
- if target_text and s_norm not in target_lower:
 
320
  continue
321
  tokens = s_norm.split()
322
  if len(tokens) > 8 or len(s_norm) > 80:
 
173
  # Validate and format inputs for the pipeline
174
  pipeline_inputs = []
175
  valid_indices = []
176
+ contexts = [] # Store contexts for later use in cleaning
177
 
178
  for i, item in enumerate(inputs):
179
  utterance = item.get("utterance", "").strip()
 
195
  'context': context # Now includes conversation history
196
  })
197
  valid_indices.append(i)
198
+ contexts.append(context) # Store for later use
199
 
200
  # Run prediction
201
  results = []
 
260
  raw_answers = [p.get("answer", "") for p in current_preds]
261
  raw_answers = [a for a in raw_answers if is_good_span(a)]
262
 
263
+ # Extract context text (part before [TARGET] marker)
264
+ full_context = contexts[pred_idx]
265
+ context_without_target = full_context.split("[TARGET]")[0].strip() if "[TARGET]" in full_context else ""
266
+
267
+ # Clean spans against BOTH target utterance AND context
268
+ triggers = self._clean_spans(raw_answers, utterance, context_without_target)
269
 
270
  results.append({
271
  "utterance": utterance,
 
286
  "triggers": []
287
  } for item in inputs]
288
 
289
+ def _clean_spans(self, spans: List[str], target_text: str, context_text: str = "") -> List[str]:
290
  """
291
  Clean and filter extracted trigger spans.
292
+ Spans can come from either target_text or context_text.
293
  """
294
  target_text = target_text or ""
295
+ context_text = context_text or ""
296
+ full_text = (context_text + " " + target_text).lower()
297
  target_lower = target_text.lower()
298
 
299
  def _norm(s: str) -> str:
 
324
  s_norm = _norm(s)
325
  if not s_norm:
326
  continue
327
+ # Check if span exists in EITHER target OR context
328
+ if full_text and s_norm not in full_text:
329
  continue
330
  tokens = s_norm.split()
331
  if len(tokens) > 8 or len(s_norm) > 80: