Khriis committed on
Commit
34873e0
·
verified ·
1 Parent(s): 3a7a32b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -170
handler.py CHANGED
@@ -22,6 +22,7 @@ class EndpointHandler:
22
  if not cuda_available:
23
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
24
 
 
25
  self.device_id = 0 if cuda_available else -1
26
 
27
  # Determine model path
@@ -29,6 +30,7 @@ class EndpointHandler:
29
  logger.info(f"Loading model from {model_path}...")
30
 
31
  try:
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_path)
33
  model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
34
  model_path,
@@ -41,6 +43,8 @@ class EndpointHandler:
41
  logger.warning("Loaded model class: %s", model.__class__.__name__)
42
  logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
43
 
 
 
44
  self.pipe = pipeline(
45
  "question-answering",
46
  model=model,
@@ -49,12 +53,6 @@ class EndpointHandler:
49
  top_k=20,
50
  handle_impossible_answer=False
51
  )
52
-
53
- # Store tokenizer for context window management
54
- self.tokenizer = tokenizer
55
- # Set max context length (adjust based on your model's max_position_embeddings)
56
- self.max_context_tokens = 384 # Conservative limit for BERT-based models
57
-
58
  logger.info("Model loaded successfully.")
59
  except Exception as e:
60
  logger.error(f"Failed to load model: {e}")
@@ -65,100 +63,11 @@ class EndpointHandler:
65
  "Extract the exact short phrase (<= 8 words) from the target "
66
  "utterance that most strongly signals the emotion {emotion}. "
67
  "Return only a substring of the target utterance."
68
- )
69
-
70
- def _build_context(self, target_utterance: str, conversation_history: List[Dict[str, str]],
71
- max_history: int = 5) -> str:
72
- """
73
- Build conversational context by prepending previous utterances.
74
-
75
- Args:
76
- target_utterance: The main utterance to analyze
77
- conversation_history: List of previous utterances, each with 'speaker' and 'text'
78
- Format: [{"speaker": "A", "text": "..."}, ...]
79
- max_history: Maximum number of previous turns to include
80
-
81
- Returns:
82
- Formatted context string
83
- """
84
- if not conversation_history:
85
- return target_utterance
86
-
87
- # Take the most recent turns (up to max_history)
88
- recent_history = conversation_history[-max_history:] if len(conversation_history) > max_history else conversation_history
89
-
90
- # Build context string
91
- context_parts = []
92
- for turn in recent_history:
93
- speaker = turn.get("speaker", "")
94
- text = turn.get("text", "").strip()
95
- if text:
96
- if speaker:
97
- context_parts.append(f"{speaker}: {text}")
98
- else:
99
- context_parts.append(text)
100
-
101
- # Add separator before target utterance
102
- context_parts.append(f"[TARGET] {target_utterance}")
103
-
104
- full_context = " ".join(context_parts)
105
-
106
- # Token-based truncation to fit within model limits
107
- return self._truncate_context(full_context, target_utterance)
108
-
109
- def _truncate_context(self, full_context: str, target_utterance: str) -> str:
110
- """
111
- Truncate context to fit within token limits while preserving target utterance.
112
- """
113
- # Tokenize to check length
114
- tokens = self.tokenizer.encode(full_context, add_special_tokens=True)
115
-
116
- if len(tokens) <= self.max_context_tokens:
117
- return full_context
118
-
119
- # If too long, ensure target utterance is fully preserved
120
- # and truncate from the beginning of the context
121
- target_marker = "[TARGET]"
122
- if target_marker in full_context:
123
- parts = full_context.split(target_marker)
124
- if len(parts) == 2:
125
- prefix, target_part = parts
126
- target_with_marker = f"{target_marker} {target_part}"
127
-
128
- # Calculate tokens for target
129
- target_tokens = self.tokenizer.encode(target_with_marker, add_special_tokens=False)
130
- available_for_prefix = self.max_context_tokens - len(target_tokens) - 10 # Buffer for special tokens
131
-
132
- if available_for_prefix > 0:
133
- # Truncate prefix from the left (keep most recent context)
134
- prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
135
- if len(prefix_tokens) > available_for_prefix:
136
- prefix_tokens = prefix_tokens[-available_for_prefix:]
137
- prefix = self.tokenizer.decode(prefix_tokens, skip_special_tokens=True)
138
-
139
- return f"{prefix} {target_with_marker}"
140
-
141
- # Fallback: just return target utterance
142
- logger.warning("Context truncation fallback - returning target only")
143
- return target_utterance
144
 
145
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
146
  """
147
  Process inference request.
148
-
149
- Expected input format (NEW):
150
- {
151
- "inputs": [
152
- {
153
- "utterance": "I'm so happy today!",
154
- "emotion": "joy",
155
- "conversation_history": [ # OPTIONAL
156
- {"speaker": "A", "text": "How are you doing?"},
157
- {"speaker": "B", "text": "Pretty good, thanks!"}
158
- ]
159
- }
160
- ]
161
- }
162
  """
163
  # Extract inputs
164
  inputs = data.pop("inputs", data)
@@ -173,40 +82,30 @@ class EndpointHandler:
173
  # Validate and format inputs for the pipeline
174
  pipeline_inputs = []
175
  valid_indices = []
176
- contexts = [] # Store contexts for later use in cleaning
177
 
178
  for i, item in enumerate(inputs):
179
  utterance = item.get("utterance", "").strip()
180
  emotion = item.get("emotion", "")
181
- conversation_history = item.get("conversation_history", [])
182
-
183
- # Log input details
184
- logger.info(f"Turn {i}: utterance='{utterance[:50]}...', emotion={emotion}, history_len={len(conversation_history)}")
185
- if conversation_history:
186
- logger.info(f" History: {conversation_history}")
187
 
188
  if not utterance:
189
  logger.warning(f"Empty utterance at index {i}")
190
  continue
191
 
192
- # Build context with conversation history
193
- context = self._build_context(utterance, conversation_history)
194
- logger.info(f"Built context for turn {i}: '{context}'")
195
-
196
  # Format as QA task
197
  question = self.question_template.format(emotion=emotion)
198
 
 
199
  pipeline_inputs.append({
200
  'question': question,
201
- 'context': context # Now includes conversation history
202
  })
203
  valid_indices.append(i)
204
- contexts.append(context) # Store for later use
205
 
206
  # Run prediction
207
  results = []
208
 
209
  if not pipeline_inputs:
 
210
  for item in inputs:
211
  results.append({
212
  "utterance": item.get("utterance", ""),
@@ -217,13 +116,21 @@ class EndpointHandler:
217
  return results
218
 
219
  try:
 
220
  predictions = self.pipe(pipeline_inputs, batch_size=8)
221
 
222
- if isinstance(predictions, dict):
223
- predictions = [predictions]
 
 
224
  elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
225
- if len(pipeline_inputs) == 1:
226
- predictions = [predictions]
 
 
 
 
 
227
 
228
  logger.debug(f"Raw predictions: {predictions}")
229
 
@@ -241,16 +148,19 @@ class EndpointHandler:
241
  "triggers": []
242
  })
243
  else:
 
 
244
  current_preds = predictions[pred_idx]
 
245
 
 
246
  if isinstance(current_preds, dict):
247
  current_preds = [current_preds]
248
 
249
  logger.info(
250
  "RECCON raw spans (answer, score): %s",
251
- [(p.get("answer"), p.get("score", 0.0)) for p in current_preds[:5]]
252
  )
253
- logger.info(f"Total predictions received: {len(current_preds)}")
254
 
255
  def is_good_span(ans: str) -> bool:
256
  if not ans:
@@ -258,30 +168,17 @@ class EndpointHandler:
258
  a = ans.strip()
259
  if len(a) < 3:
260
  return False
 
261
  if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
262
  return False
 
263
  if not any(ch.isalpha() for ch in a):
264
  return False
265
- # Filter out speaker labels and prompt artifacts
266
- a_lower = a.lower()
267
- if "patient:" in a_lower or "therapist:" in a_lower or "[target]" in a_lower:
268
- return False
269
- if a_lower in ["patient", "therapist"]:
270
- return False
271
  return True
272
 
273
  raw_answers = [p.get("answer", "") for p in current_preds]
274
- logger.info(f"Raw answers before filtering: {raw_answers}")
275
-
276
  raw_answers = [a for a in raw_answers if is_good_span(a)]
277
- logger.info(f"Answers after is_good_span filter: {raw_answers}")
278
-
279
- # Extract context text (part before [TARGET] marker)
280
- full_context = contexts[pred_idx]
281
- context_without_target = full_context.split("[TARGET]")[0].strip() if "[TARGET]" in full_context else ""
282
-
283
- # Clean spans against BOTH target utterance AND context
284
- triggers = self._clean_spans(raw_answers, utterance, context_without_target)
285
 
286
  results.append({
287
  "utterance": utterance,
@@ -302,19 +199,12 @@ class EndpointHandler:
302
  "triggers": []
303
  } for item in inputs]
304
 
305
- def _clean_spans(self, spans: List[str], target_text: str, context_text: str = "") -> List[str]:
306
  """
307
  Clean and filter extracted trigger spans.
308
- Spans can come from either target_text or context_text.
309
  """
310
- logger.info(f"_clean_spans called with {len(spans)} spans")
311
- logger.info(f" Target: '{target_text}'")
312
- logger.info(f" Context: '{context_text[:100]}...'" if len(context_text) > 100 else f" Context: '{context_text}'")
313
- logger.info(f" Input spans: {spans}")
314
-
315
  target_text = target_text or ""
316
- context_text = context_text or ""
317
- full_text = (context_text + " " + target_text).lower()
318
  target_lower = target_text.lower()
319
 
320
  def _norm(s: str) -> str:
@@ -323,11 +213,10 @@ class EndpointHandler:
323
  s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
324
  return s
325
 
326
- def _extract_from_text(text: str, phrase_lower: str) -> str:
327
- """Extract phrase from text preserving original case."""
328
- idx = text.lower().find(phrase_lower)
329
  if idx >= 0:
330
- return text[idx:idx+len(phrase_lower)]
331
  return phrase_lower
332
 
333
  STOP = {
@@ -346,8 +235,7 @@ class EndpointHandler:
346
  s_norm = _norm(s)
347
  if not s_norm:
348
  continue
349
- # Check if span exists in EITHER target OR context
350
- if full_text and s_norm not in full_text:
351
  continue
352
  tokens = s_norm.split()
353
  if len(tokens) > 8 or len(s_norm) > 80:
@@ -362,35 +250,14 @@ class EndpointHandler:
362
  })
363
 
364
  candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=True)
365
- logger.info(f"Built {len(candidates)} candidates: {[c['norm'] for c in candidates]}")
366
-
367
  kept_norms = []
368
  for c in list(candidates):
369
  n = c["norm"]
370
  if any(n in kn or kn in n for kn in kept_norms):
371
  continue
372
  kept_norms.append(n)
373
-
374
- logger.info(f"After dedup: {kept_norms}")
375
-
376
- # Extract spans from either target or context (whichever contains them)
377
- cleaned = []
378
- for n in kept_norms:
379
- # Try target first, then context
380
- if n in target_lower:
381
- extracted = _extract_from_text(target_text, n)
382
- logger.info(f" Extracted '{extracted}' from TARGET")
383
- cleaned.append(extracted)
384
- elif n in context_text.lower():
385
- extracted = _extract_from_text(context_text, n)
386
- logger.info(f" Extracted '{extracted}' from CONTEXT")
387
- cleaned.append(extracted)
388
- else:
389
- # Fallback - shouldn't happen given earlier validation
390
- logger.warning(f" Phrase '{n}' not found in target or context, using normalized")
391
- cleaned.append(n)
392
-
393
- logger.info(f"Final cleaned spans: {cleaned}")
394
 
395
  if not cleaned and spans:
396
  tt_tokens = target_lower.split()
@@ -410,6 +277,6 @@ class EndpointHandler:
410
  if best:
411
  break
412
  if best:
413
- return [_extract_from_text(target_text, best)]
414
 
415
  return cleaned[:3]
 
22
  if not cuda_available:
23
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
24
 
25
+ # In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
26
  self.device_id = 0 if cuda_available else -1
27
 
28
  # Determine model path
 
30
  logger.info(f"Loading model from {model_path}...")
31
 
32
  try:
33
+ # Load tokenizer and model explicitly to ensure correct loading
34
  tokenizer = AutoTokenizer.from_pretrained(model_path)
35
  model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
36
  model_path,
 
43
  logger.warning("Loaded model class: %s", model.__class__.__name__)
44
  logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
45
 
46
+ # Initialize the pipeline
47
+ # top_k=20 matches your previous 'n_best_size=20' logic
48
  self.pipe = pipeline(
49
  "question-answering",
50
  model=model,
 
53
  top_k=20,
54
  handle_impossible_answer=False
55
  )
 
 
 
 
 
 
56
  logger.info("Model loaded successfully.")
57
  except Exception as e:
58
  logger.error(f"Failed to load model: {e}")
 
63
  "Extract the exact short phrase (<= 8 words) from the target "
64
  "utterance that most strongly signals the emotion {emotion}. "
65
  "Return only a substring of the target utterance."
66
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
69
  """
70
  Process inference request.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  """
72
  # Extract inputs
73
  inputs = data.pop("inputs", data)
 
82
  # Validate and format inputs for the pipeline
83
  pipeline_inputs = []
84
  valid_indices = []
 
85
 
86
  for i, item in enumerate(inputs):
87
  utterance = item.get("utterance", "").strip()
88
  emotion = item.get("emotion", "")
 
 
 
 
 
 
89
 
90
  if not utterance:
91
  logger.warning(f"Empty utterance at index {i}")
92
  continue
93
 
 
 
 
 
94
  # Format as QA task
95
  question = self.question_template.format(emotion=emotion)
96
 
97
+ # The pipeline expects a list of dicts with 'question' and 'context'
98
  pipeline_inputs.append({
99
  'question': question,
100
+ 'context': utterance
101
  })
102
  valid_indices.append(i)
 
103
 
104
  # Run prediction
105
  results = []
106
 
107
  if not pipeline_inputs:
108
+ # All inputs were invalid
109
  for item in inputs:
110
  results.append({
111
  "utterance": item.get("utterance", ""),
 
116
  return results
117
 
118
  try:
119
+ # Run inference (batch_size helps with multiple inputs)
120
  predictions = self.pipe(pipeline_inputs, batch_size=8)
121
 
122
+ # If batch_size=1 or single input, pipeline might return a single list/dict
123
+ # We ensure it's a list of lists (since top_k > 1)
124
+ if isinstance(predictions, dict): # Single input result
125
+ predictions = [predictions] # Wrap in list
126
  elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
127
+ # This happens if we have multiple inputs but top_k=1 (which is not the case here),
128
+ # OR if we have a single input and top_k > 1.
129
+ # If we have multiple inputs and top_k > 1, it returns a list of lists.
130
+ if len(pipeline_inputs) == 1:
131
+ predictions = [predictions]
132
+ # If multiple inputs and list of dicts, that implies top_k=1.
133
+ # But we set top_k=20. So it should be list of lists.
134
 
135
  logger.debug(f"Raw predictions: {predictions}")
136
 
 
148
  "triggers": []
149
  })
150
  else:
151
+ # Get prediction for this item
152
+ # Because top_k=20, 'current_preds' is a list of dicts: [{'answer': '...', 'score': ...}, ...]
153
  current_preds = predictions[pred_idx]
154
+
155
 
156
+ # Ensure it is a list
157
  if isinstance(current_preds, dict):
158
  current_preds = [current_preds]
159
 
160
  logger.info(
161
  "RECCON raw spans (answer, score): %s",
162
+ [(p.get("answer"), p.get("score", 0.0), 3) for p in current_preds[:5]]
163
  )
 
164
 
165
  def is_good_span(ans: str) -> bool:
166
  if not ans:
 
168
  a = ans.strip()
169
  if len(a) < 3:
170
  return False
171
+ # reject pure punctuation
172
  if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
173
  return False
174
+ # require at least one letter
175
  if not any(ch.isalpha() for ch in a):
176
  return False
 
 
 
 
 
 
177
  return True
178
 
179
  raw_answers = [p.get("answer", "") for p in current_preds]
 
 
180
  raw_answers = [a for a in raw_answers if is_good_span(a)]
181
+ triggers = self._clean_spans(raw_answers, utterance)
 
 
 
 
 
 
 
182
 
183
  results.append({
184
  "utterance": utterance,
 
199
  "triggers": []
200
  } for item in inputs]
201
 
202
+ def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
203
  """
204
  Clean and filter extracted trigger spans.
205
+ (Logic preserved exactly as provided)
206
  """
 
 
 
 
 
207
  target_text = target_text or ""
 
 
208
  target_lower = target_text.lower()
209
 
210
  def _norm(s: str) -> str:
 
213
  s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
214
  return s
215
 
216
+ def _extract_from_target(target: str, phrase_lower: str) -> str:
217
+ idx = target.lower().find(phrase_lower)
 
218
  if idx >= 0:
219
+ return target[idx:idx+len(phrase_lower)]
220
  return phrase_lower
221
 
222
  STOP = {
 
235
  s_norm = _norm(s)
236
  if not s_norm:
237
  continue
238
+ if target_text and s_norm not in target_lower:
 
239
  continue
240
  tokens = s_norm.split()
241
  if len(tokens) > 8 or len(s_norm) > 80:
 
250
  })
251
 
252
  candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=True)
 
 
253
  kept_norms = []
254
  for c in list(candidates):
255
  n = c["norm"]
256
  if any(n in kn or kn in n for kn in kept_norms):
257
  continue
258
  kept_norms.append(n)
259
+
260
+ cleaned = [_extract_from_target(target_text, n) for n in kept_norms]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if not cleaned and spans:
263
  tt_tokens = target_lower.split()
 
277
  if best:
278
  break
279
  if best:
280
+ return [_extract_from_target(target_text, best)]
281
 
282
  return cleaned[:3]