Khriis committed on
Commit
aaa853a
·
verified ·
1 Parent(s): 1b95ed9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +108 -27
handler.py CHANGED
@@ -22,7 +22,6 @@ class EndpointHandler:
22
  if not cuda_available:
23
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
24
 
25
- # In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
26
  self.device_id = 0 if cuda_available else -1
27
 
28
  # Determine model path
@@ -30,7 +29,6 @@ class EndpointHandler:
30
  logger.info(f"Loading model from {model_path}...")
31
 
32
  try:
33
- # Load tokenizer and model explicitly to ensure correct loading
34
  tokenizer = AutoTokenizer.from_pretrained(model_path)
35
  model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
36
  model_path,
@@ -43,8 +41,6 @@ class EndpointHandler:
43
  logger.warning("Loaded model class: %s", model.__class__.__name__)
44
  logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
45
 
46
- # Initialize the pipeline
47
- # top_k=20 matches your previous 'n_best_size=20' logic
48
  self.pipe = pipeline(
49
  "question-answering",
50
  model=model,
@@ -53,6 +49,12 @@ class EndpointHandler:
53
  top_k=20,
54
  handle_impossible_answer=False
55
  )
 
 
 
 
 
 
56
  logger.info("Model loaded successfully.")
57
  except Exception as e:
58
  logger.error(f"Failed to load model: {e}")
@@ -63,11 +65,100 @@ class EndpointHandler:
63
  "Extract the exact short phrase (<= 8 words) from the target "
64
  "utterance that most strongly signals the emotion {emotion}. "
65
  "Return only a substring of the target utterance."
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
69
  """
70
  Process inference request.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  """
72
  # Extract inputs
73
  inputs = data.pop("inputs", data)
@@ -86,18 +177,21 @@ class EndpointHandler:
86
  for i, item in enumerate(inputs):
87
  utterance = item.get("utterance", "").strip()
88
  emotion = item.get("emotion", "")
 
89
 
90
  if not utterance:
91
  logger.warning(f"Empty utterance at index {i}")
92
  continue
93
 
 
 
 
94
  # Format as QA task
95
  question = self.question_template.format(emotion=emotion)
96
 
97
- # The pipeline expects a list of dicts with 'question' and 'context'
98
  pipeline_inputs.append({
99
  'question': question,
100
- 'context': utterance
101
  })
102
  valid_indices.append(i)
103
 
@@ -105,7 +199,6 @@ class EndpointHandler:
105
  results = []
106
 
107
  if not pipeline_inputs:
108
- # All inputs were invalid
109
  for item in inputs:
110
  results.append({
111
  "utterance": item.get("utterance", ""),
@@ -116,21 +209,13 @@ class EndpointHandler:
116
  return results
117
 
118
  try:
119
- # Run inference (batch_size helps with multiple inputs)
120
  predictions = self.pipe(pipeline_inputs, batch_size=8)
121
 
122
- # If batch_size=1 or single input, pipeline might return a single list/dict
123
- # We ensure it's a list of lists (since top_k > 1)
124
- if isinstance(predictions, dict): # Single input result
125
- predictions = [predictions] # Wrap in list
126
  elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
127
- # This happens if we have multiple inputs but top_k=1 (which is not the case here),
128
- # OR if we have a single input and top_k > 1.
129
- # If we have multiple inputs and top_k > 1, it returns a list of lists.
130
- if len(pipeline_inputs) == 1:
131
- predictions = [predictions]
132
- # If multiple inputs and list of dicts, that implies top_k=1.
133
- # But we set top_k=20. So it should be list of lists.
134
 
135
  logger.debug(f"Raw predictions: {predictions}")
136
 
@@ -148,18 +233,14 @@ class EndpointHandler:
148
  "triggers": []
149
  })
150
  else:
151
- # Get prediction for this item
152
- # Because top_k=20, 'current_preds' is a list of dicts: [{'answer': '...', 'score': ...}, ...]
153
  current_preds = predictions[pred_idx]
154
-
155
 
156
- # Ensure it is a list
157
  if isinstance(current_preds, dict):
158
  current_preds = [current_preds]
159
 
160
  logger.info(
161
  "RECCON raw spans (answer, score): %s",
162
- [(p.get("answer"), p.get("score", 0.0), 3) for p in current_preds[:5]]
163
  )
164
 
165
  def is_good_span(ans: str) -> bool:
@@ -168,16 +249,16 @@ class EndpointHandler:
168
  a = ans.strip()
169
  if len(a) < 3:
170
  return False
171
- # reject pure punctuation
172
  if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
173
  return False
174
- # require at least one letter
175
  if not any(ch.isalpha() for ch in a):
176
  return False
177
  return True
178
 
179
  raw_answers = [p.get("answer", "") for p in current_preds]
180
  raw_answers = [a for a in raw_answers if is_good_span(a)]
 
 
181
  triggers = self._clean_spans(raw_answers, utterance)
182
 
183
  results.append({
 
22
  if not cuda_available:
23
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
24
 
 
25
  self.device_id = 0 if cuda_available else -1
26
 
27
  # Determine model path
 
29
  logger.info(f"Loading model from {model_path}...")
30
 
31
  try:
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_path)
33
  model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
34
  model_path,
 
41
  logger.warning("Loaded model class: %s", model.__class__.__name__)
42
  logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
43
 
 
 
44
  self.pipe = pipeline(
45
  "question-answering",
46
  model=model,
 
49
  top_k=20,
50
  handle_impossible_answer=False
51
  )
52
+
53
+ # Store tokenizer for context window management
54
+ self.tokenizer = tokenizer
55
+ # Set max context length (adjust based on your model's max_position_embeddings)
56
+ self.max_context_tokens = 384 # Conservative limit for BERT-based models
57
+
58
  logger.info("Model loaded successfully.")
59
  except Exception as e:
60
  logger.error(f"Failed to load model: {e}")
 
65
  "Extract the exact short phrase (<= 8 words) from the target "
66
  "utterance that most strongly signals the emotion {emotion}. "
67
  "Return only a substring of the target utterance."
68
+ )
69
+
70
+ def _build_context(self, target_utterance: str, conversation_history: List[Dict[str, str]],
71
+ max_history: int = 5) -> str:
72
+ """
73
+ Build conversational context by prepending previous utterances.
74
+
75
+ Args:
76
+ target_utterance: The main utterance to analyze
77
+ conversation_history: List of previous utterances, each with 'speaker' and 'text'
78
+ Format: [{"speaker": "A", "text": "..."}, ...]
79
+ max_history: Maximum number of previous turns to include
80
+
81
+ Returns:
82
+ Formatted context string
83
+ """
84
+ if not conversation_history:
85
+ return target_utterance
86
+
87
+ # Take the most recent turns (up to max_history)
88
+ recent_history = conversation_history[-max_history:] if len(conversation_history) > max_history else conversation_history
89
+
90
+ # Build context string
91
+ context_parts = []
92
+ for turn in recent_history:
93
+ speaker = turn.get("speaker", "")
94
+ text = turn.get("text", "").strip()
95
+ if text:
96
+ if speaker:
97
+ context_parts.append(f"{speaker}: {text}")
98
+ else:
99
+ context_parts.append(text)
100
+
101
+ # Add separator before target utterance
102
+ context_parts.append(f"[TARGET] {target_utterance}")
103
+
104
+ full_context = " ".join(context_parts)
105
+
106
+ # Token-based truncation to fit within model limits
107
+ return self._truncate_context(full_context, target_utterance)
108
+
109
+ def _truncate_context(self, full_context: str, target_utterance: str) -> str:
110
+ """
111
+ Truncate context to fit within token limits while preserving target utterance.
112
+ """
113
+ # Tokenize to check length
114
+ tokens = self.tokenizer.encode(full_context, add_special_tokens=True)
115
+
116
+ if len(tokens) <= self.max_context_tokens:
117
+ return full_context
118
+
119
+ # If too long, ensure target utterance is fully preserved
120
+ # and truncate from the beginning of the context
121
+ target_marker = "[TARGET]"
122
+ if target_marker in full_context:
123
+ parts = full_context.split(target_marker)
124
+ if len(parts) == 2:
125
+ prefix, target_part = parts
126
+ target_with_marker = f"{target_marker} {target_part}"
127
+
128
+ # Calculate tokens for target
129
+ target_tokens = self.tokenizer.encode(target_with_marker, add_special_tokens=False)
130
+ available_for_prefix = self.max_context_tokens - len(target_tokens) - 10 # Buffer for special tokens
131
+
132
+ if available_for_prefix > 0:
133
+ # Truncate prefix from the left (keep most recent context)
134
+ prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
135
+ if len(prefix_tokens) > available_for_prefix:
136
+ prefix_tokens = prefix_tokens[-available_for_prefix:]
137
+ prefix = self.tokenizer.decode(prefix_tokens, skip_special_tokens=True)
138
+
139
+ return f"{prefix} {target_with_marker}"
140
+
141
+ # Fallback: just return target utterance
142
+ logger.warning("Context truncation fallback - returning target only")
143
+ return target_utterance
144
 
145
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
146
  """
147
  Process inference request.
148
+
149
+ Expected input format (NEW):
150
+ {
151
+ "inputs": [
152
+ {
153
+ "utterance": "I'm so happy today!",
154
+ "emotion": "joy",
155
+ "conversation_history": [ # OPTIONAL
156
+ {"speaker": "A", "text": "How are you doing?"},
157
+ {"speaker": "B", "text": "Pretty good, thanks!"}
158
+ ]
159
+ }
160
+ ]
161
+ }
162
  """
163
  # Extract inputs
164
  inputs = data.pop("inputs", data)
 
177
  for i, item in enumerate(inputs):
178
  utterance = item.get("utterance", "").strip()
179
  emotion = item.get("emotion", "")
180
+ conversation_history = item.get("conversation_history", [])
181
 
182
  if not utterance:
183
  logger.warning(f"Empty utterance at index {i}")
184
  continue
185
 
186
+ # Build context with conversation history
187
+ context = self._build_context(utterance, conversation_history)
188
+
189
  # Format as QA task
190
  question = self.question_template.format(emotion=emotion)
191
 
 
192
  pipeline_inputs.append({
193
  'question': question,
194
+ 'context': context # Now includes conversation history
195
  })
196
  valid_indices.append(i)
197
 
 
199
  results = []
200
 
201
  if not pipeline_inputs:
 
202
  for item in inputs:
203
  results.append({
204
  "utterance": item.get("utterance", ""),
 
209
  return results
210
 
211
  try:
 
212
  predictions = self.pipe(pipeline_inputs, batch_size=8)
213
 
214
+ if isinstance(predictions, dict):
215
+ predictions = [predictions]
 
 
216
  elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
217
+ if len(pipeline_inputs) == 1:
218
+ predictions = [predictions]
 
 
 
 
 
219
 
220
  logger.debug(f"Raw predictions: {predictions}")
221
 
 
233
  "triggers": []
234
  })
235
  else:
 
 
236
  current_preds = predictions[pred_idx]
 
237
 
 
238
  if isinstance(current_preds, dict):
239
  current_preds = [current_preds]
240
 
241
  logger.info(
242
  "RECCON raw spans (answer, score): %s",
243
+ [(p.get("answer"), p.get("score", 0.0)) for p in current_preds[:5]]
244
  )
245
 
246
  def is_good_span(ans: str) -> bool:
 
249
  a = ans.strip()
250
  if len(a) < 3:
251
  return False
 
252
  if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
253
  return False
 
254
  if not any(ch.isalpha() for ch in a):
255
  return False
256
  return True
257
 
258
  raw_answers = [p.get("answer", "") for p in current_preds]
259
  raw_answers = [a for a in raw_answers if is_good_span(a)]
260
+
261
+ # Clean spans against ORIGINAL utterance (not full context)
262
  triggers = self._clean_spans(raw_answers, utterance)
263
 
264
  results.append({