Khriis committed on
Commit
ca8e0c5
verified
1 Parent(s): 3e79e47

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +62 -94
handler.py CHANGED
@@ -2,18 +2,16 @@ import torch
2
  import logging
3
  import re
4
  from typing import Dict, List, Any
5
- from simpletransformers.question_answering import QuestionAnsweringModel
6
 
7
- # Configure logging (no file I/O for serverless environment)
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
-
12
  class EndpointHandler:
13
  def __init__(self, path=""):
14
  """
15
- Initialize the RECCON emotional trigger extraction model.
16
-
17
  Args:
18
  path: Path to model directory (provided by HuggingFace Inference Endpoints)
19
  """
@@ -23,33 +21,28 @@ class EndpointHandler:
23
  cuda_available = torch.cuda.is_available()
24
  if not cuda_available:
25
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
26
- self.device = torch.device("cuda" if cuda_available else "cpu")
27
- cuda_device = 0 if cuda_available else -1
 
28
 
29
  # Determine model path
30
- if not path or path == ".":
31
- model_path = "."
32
- else:
33
- model_path = path
34
-
35
  logger.info(f"Loading model from {model_path}...")
36
 
37
- # Load the QuestionAnsweringModel using simpletransformers
38
  try:
39
- self.model = QuestionAnsweringModel(
40
- "roberta",
41
- model_path,
42
- args={
43
- "silent_tf_logger": True,
44
- "eval_batch_size": 8,
45
- "device_map": None,
46
- "max_seq_length": 512,
47
- "max_answer_length": 200,
48
- "n_best_size": 20,
49
- "doc_stride": 512
50
- },
51
- use_cuda=cuda_available,
52
- cuda_device=cuda_device
53
  )
54
  logger.info("Model loaded successfully.")
55
  except Exception as e:
@@ -66,39 +59,19 @@ class EndpointHandler:
66
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
67
  """
68
  Process inference request.
69
-
70
- Args:
71
- data: Request data with structure:
72
- {
73
- "inputs": [
74
- {"utterance": "text", "emotion": "happiness"},
75
- ...
76
- ]
77
- }
78
-
79
- Returns:
80
- List of results:
81
- [
82
- {
83
- "utterance": "text",
84
- "emotion": "happiness",
85
- "triggers": ["trigger phrase 1", "trigger phrase 2"]
86
- },
87
- ...
88
- ]
89
  """
90
  # Extract inputs
91
  inputs = data.pop("inputs", data)
92
 
93
- # Normalize to list format (handle single dict)
94
  if isinstance(inputs, dict):
95
  inputs = [inputs]
96
 
97
  if not inputs:
98
  return [{"error": "No inputs provided", "triggers": []}]
99
 
100
- # Validate and format inputs
101
- qa_inputs = []
102
  valid_indices = []
103
 
104
  for i, item in enumerate(inputs):
@@ -111,19 +84,18 @@ class EndpointHandler:
111
 
112
  # Format as QA task
113
  question = self.question_template.format(emotion=emotion)
114
- qa_inputs.append({
115
- 'context': utterance,
116
- 'qas': [{
117
- 'id': f'temp_id_{i}',
118
- 'question': question
119
- }]
120
  })
121
  valid_indices.append(i)
122
 
123
  # Run prediction
124
  results = []
125
 
126
- if not qa_inputs:
127
  # All inputs were invalid
128
  for item in inputs:
129
  results.append({
@@ -135,17 +107,31 @@ class EndpointHandler:
135
  return results
136
 
137
  try:
138
- predictions, _ = self.model.predict(qa_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  logger.debug(f"Raw predictions: {predictions}")
140
 
141
  # Post-process results
142
- result_idx = 0
143
  for i, item in enumerate(inputs):
144
  utterance = item.get("utterance", "").strip()
145
  emotion = item.get("emotion", "")
146
 
147
  if i not in valid_indices:
148
- # Invalid input
149
  results.append({
150
  "utterance": utterance,
151
  "emotion": emotion,
@@ -153,32 +139,32 @@ class EndpointHandler:
153
  "triggers": []
154
  })
155
  else:
156
- # Valid input - process prediction
157
- prediction = predictions[result_idx]
158
- answer = prediction.get('answer')
159
-
160
- # Extract and clean spans
161
- if isinstance(answer, list) and len(answer) > 0:
162
- non_empty_answers = [a for a in answer if a]
163
- triggers = self._clean_spans(non_empty_answers, utterance)
164
- elif isinstance(answer, str):
165
- triggers = self._clean_spans([answer], utterance)
166
- else:
167
- triggers = []
 
168
 
169
  results.append({
170
  "utterance": utterance,
171
  "emotion": emotion,
172
  "triggers": triggers
173
  })
174
- result_idx += 1
175
 
176
  logger.debug(f"Cleaned results: {results}")
177
  return results
178
 
179
  except Exception as e:
180
  logger.error(f"Model prediction failed: {e}")
181
- # Return error for all inputs
182
  return [{
183
  "utterance": item.get("utterance", ""),
184
  "emotion": item.get("emotion", ""),
@@ -189,36 +175,23 @@ class EndpointHandler:
189
  def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
190
  """
191
  Clean and filter extracted trigger spans.
192
-
193
- This function preserves all the post-processing logic from predict_trigger.py
194
- (lines 78-153) including stopword filtering, length constraints, deduplication,
195
- and n-gram fallback.
196
-
197
- Args:
198
- spans: Raw spans extracted by the model
199
- target_text: Original utterance text
200
-
201
- Returns:
202
- List of up to 3 cleaned trigger phrases
203
  """
204
  target_text = target_text or ""
205
  target_lower = target_text.lower()
206
 
207
  def _norm(s: str) -> str:
208
- """Normalize a string: strip, lowercase, remove extra spaces and punctuation."""
209
  s = (s or "").strip().lower()
210
  s = re.sub(r"\s+", " ", s)
211
  s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
212
  return s
213
 
214
  def _extract_from_target(target: str, phrase_lower: str) -> str:
215
- """Extract phrase from target with original casing."""
216
  idx = target.lower().find(phrase_lower)
217
  if idx >= 0:
218
  return target[idx:idx+len(phrase_lower)]
219
  return phrase_lower
220
 
221
- # Stopwords to filter out
222
  STOP = {
223
  "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
224
  "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
@@ -227,7 +200,6 @@ class EndpointHandler:
227
  "those"
228
  }
229
 
230
- # Collect candidate spans that are substrings of target and reasonable length
231
  candidates = []
232
  for s in spans:
233
  s = (s or "").strip()
@@ -250,7 +222,6 @@ class EndpointHandler:
250
  "char_len": len(s_norm)
251
  })
252
 
253
- # Prefer longer phrases; remove subsumed/duplicate fragments
254
  candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=True)
255
  kept_norms = []
256
  for c in list(candidates):
@@ -262,8 +233,6 @@ class EndpointHandler:
262
  cleaned = [_extract_from_target(target_text, n) for n in kept_norms]
263
 
264
  if not cleaned and spans:
265
- # Fallback: try to salvage a sub-span that actually exists
266
- # in the target utterance by scanning n-grams up to 8 words
267
  tt_tokens = target_lower.split()
268
  best = None
269
  for s in spans:
@@ -271,7 +240,6 @@ class EndpointHandler:
271
  for L in range(min(8, len(words)), 0, -1):
272
  for i in range(len(words) - L + 1):
273
  phrase = words[i:i+L]
274
- # contiguous n-gram match on token boundaries
275
  for j in range(len(tt_tokens) - L + 1):
276
  if tt_tokens[j:j+L] == phrase:
277
  cand = " ".join(phrase)
@@ -284,4 +252,4 @@ class EndpointHandler:
284
  if best:
285
  return [_extract_from_target(target_text, best)]
286
 
287
- return cleaned[:3]
 
2
  import logging
3
  import re
4
  from typing import Dict, List, Any
5
+ from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
6
 
7
+ # Configure logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  """
14
+ Initialize the RECCON emotional trigger extraction model using native transformers.
 
15
  Args:
16
  path: Path to model directory (provided by HuggingFace Inference Endpoints)
17
  """
 
21
  cuda_available = torch.cuda.is_available()
22
  if not cuda_available:
23
  logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
24
+
25
+ # In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
26
+ self.device_id = 0 if cuda_available else -1
27
 
28
  # Determine model path
29
+ model_path = path if path and path != "." else "."
 
 
 
 
30
  logger.info(f"Loading model from {model_path}...")
31
 
 
32
  try:
33
+ # Load tokenizer and model explicitly to ensure correct loading
34
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
35
+ model = AutoModelForQuestionAnswering.from_pretrained(model_path)
36
+
37
+ # Initialize the pipeline
38
+ # top_k=20 matches your previous 'n_best_size=20' logic
39
+ self.pipe = pipeline(
40
+ "question-answering",
41
+ model=model,
42
+ tokenizer=tokenizer,
43
+ device=self.device_id,
44
+ top_k=20,
45
+ handle_impossible_answer=False
 
46
  )
47
  logger.info("Model loaded successfully.")
48
  except Exception as e:
 
59
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
60
  """
61
  Process inference request.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  """
63
  # Extract inputs
64
  inputs = data.pop("inputs", data)
65
 
66
+ # Normalize to list format
67
  if isinstance(inputs, dict):
68
  inputs = [inputs]
69
 
70
  if not inputs:
71
  return [{"error": "No inputs provided", "triggers": []}]
72
 
73
+ # Validate and format inputs for the pipeline
74
+ pipeline_inputs = []
75
  valid_indices = []
76
 
77
  for i, item in enumerate(inputs):
 
84
 
85
  # Format as QA task
86
  question = self.question_template.format(emotion=emotion)
87
+
88
+ # The pipeline expects a list of dicts with 'question' and 'context'
89
+ pipeline_inputs.append({
90
+ 'question': question,
91
+ 'context': utterance
 
92
  })
93
  valid_indices.append(i)
94
 
95
  # Run prediction
96
  results = []
97
 
98
+ if not pipeline_inputs:
99
  # All inputs were invalid
100
  for item in inputs:
101
  results.append({
 
107
  return results
108
 
109
  try:
110
+ # Run inference (batch_size helps with multiple inputs)
111
+ predictions = self.pipe(pipeline_inputs, batch_size=8)
112
+
113
+ # If batch_size=1 or single input, pipeline might return a single list/dict
114
+ # We ensure it's a list of lists (since top_k > 1)
115
+ if isinstance(predictions, dict): # Single input result
116
+ predictions = [predictions] # Wrap in list
117
+ elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
118
+ # This happens if we have multiple inputs but top_k=1 (which is not the case here),
119
+ # OR if we have a single input and top_k > 1.
120
+ # If we have multiple inputs and top_k > 1, it returns a list of lists.
121
+ if len(pipeline_inputs) == 1:
122
+ predictions = [predictions]
123
+ # If multiple inputs and list of dicts, that implies top_k=1.
124
+ # But we set top_k=20. So it should be list of lists.
125
+
126
  logger.debug(f"Raw predictions: {predictions}")
127
 
128
  # Post-process results
129
+ pred_idx = 0
130
  for i, item in enumerate(inputs):
131
  utterance = item.get("utterance", "").strip()
132
  emotion = item.get("emotion", "")
133
 
134
  if i not in valid_indices:
 
135
  results.append({
136
  "utterance": utterance,
137
  "emotion": emotion,
 
139
  "triggers": []
140
  })
141
  else:
142
+ # Get prediction for this item
143
+ # Because top_k=20, 'current_preds' is a list of dicts: [{'answer': '...', 'score': ...}, ...]
144
+ current_preds = predictions[pred_idx]
145
+
146
+ # Ensure it is a list
147
+ if isinstance(current_preds, dict):
148
+ current_preds = [current_preds]
149
+
150
+ # Extract the answer strings
151
+ raw_answers = [p.get('answer', '') for p in current_preds]
152
+
153
+ # Clean spans using your original logic
154
+ triggers = self._clean_spans(raw_answers, utterance)
155
 
156
  results.append({
157
  "utterance": utterance,
158
  "emotion": emotion,
159
  "triggers": triggers
160
  })
161
+ pred_idx += 1
162
 
163
  logger.debug(f"Cleaned results: {results}")
164
  return results
165
 
166
  except Exception as e:
167
  logger.error(f"Model prediction failed: {e}")
 
168
  return [{
169
  "utterance": item.get("utterance", ""),
170
  "emotion": item.get("emotion", ""),
 
175
  def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
176
  """
177
  Clean and filter extracted trigger spans.
178
+ (Logic preserved exactly as provided)
 
 
 
 
 
 
 
 
 
 
179
  """
180
  target_text = target_text or ""
181
  target_lower = target_text.lower()
182
 
183
  def _norm(s: str) -> str:
 
184
  s = (s or "").strip().lower()
185
  s = re.sub(r"\s+", " ", s)
186
  s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
187
  return s
188
 
189
  def _extract_from_target(target: str, phrase_lower: str) -> str:
 
190
  idx = target.lower().find(phrase_lower)
191
  if idx >= 0:
192
  return target[idx:idx+len(phrase_lower)]
193
  return phrase_lower
194
 
 
195
  STOP = {
196
  "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
197
  "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
 
200
  "those"
201
  }
202
 
 
203
  candidates = []
204
  for s in spans:
205
  s = (s or "").strip()
 
222
  "char_len": len(s_norm)
223
  })
224
 
 
225
  candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=True)
226
  kept_norms = []
227
  for c in list(candidates):
 
233
  cleaned = [_extract_from_target(target_text, n) for n in kept_norms]
234
 
235
  if not cleaned and spans:
 
 
236
  tt_tokens = target_lower.split()
237
  best = None
238
  for s in spans:
 
240
  for L in range(min(8, len(words)), 0, -1):
241
  for i in range(len(words) - L + 1):
242
  phrase = words[i:i+L]
 
243
  for j in range(len(tt_tokens) - L + 1):
244
  if tt_tokens[j:j+L] == phrase:
245
  cand = " ".join(phrase)
 
252
  if best:
253
  return [_extract_from_target(target_text, best)]
254
 
255
+ return cleaned[:3]