Khriis committed on
Commit
1319732
verified
1 Parent(s): de42880

Uploaded infrastructure and model files

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForQuestionAnswering"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "classifier_dropout": null,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 768,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "roberta",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "position_embedding_type": "absolute",
22
+ "transformers_version": "4.57.1",
23
+ "type_vocab_size": 1,
24
+ "use_cache": true,
25
+ "vocab_size": 50265
26
+ }
handler.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import re
4
+ from typing import Dict, List, Any
5
+ from simpletransformers.question_answering import QuestionAnsweringModel
6
+
7
+ # Configure logging (no file I/O for serverless environment)
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for RECCON trigger extraction.

    Each (utterance, emotion) pair is reformulated as an extractive-QA
    example; the RoBERTa QA model's answer span is the phrase in the
    utterance that signals the emotion ("trigger").
    """

    def __init__(self, path=""):
        """
        Initialize the RECCON emotional trigger extraction model.

        Args:
            path: Path to model directory (provided by HuggingFace Inference
                Endpoints). Empty or "." falls back to the current directory.

        Raises:
            Exception: re-raised if the underlying model fails to load.
        """
        logger.info("Initializing RECCON Trigger Extraction endpoint...")

        # Detect device (CUDA/CPU)
        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
        self.device = torch.device("cuda" if cuda_available else "cpu")
        cuda_device = 0 if cuda_available else -1

        # Determine model path ("." when the platform passes nothing useful)
        model_path = path if path and path != "." else "."

        logger.info(f"Loading model from {model_path}...")

        # Load the QuestionAnsweringModel using simpletransformers
        try:
            self.model = QuestionAnsweringModel(
                "roberta",
                model_path,
                args={
                    "silent_tf_logger": True,
                    "eval_batch_size": 8,
                    "device_map": None,
                    "max_seq_length": 512,
                    "max_answer_length": 200,
                    "n_best_size": 20,
                    "doc_stride": 512
                },
                use_cuda=cuda_available,
                cuda_device=cuda_device
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        # Question template (must match training)
        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Args:
            data: Request data with structure:
                {
                    "inputs": [
                        {"utterance": "text", "emotion": "happiness"},
                        ...
                    ]
                }
                A bare single dict (without the "inputs" wrapper) is also
                accepted.

        Returns:
            List of results, one per input, preserving order:
                [
                    {
                        "utterance": "text",
                        "emotion": "happiness",
                        "triggers": ["trigger phrase 1", "trigger phrase 2"]
                    },
                    ...
                ]
            Invalid entries carry an "error" key and empty "triggers".
        """
        # Extract inputs WITHOUT mutating the caller's payload
        # (previous version used data.pop(), which altered the request dict).
        inputs = data.get("inputs", data)

        # Normalize to list format (handle single dict)
        if isinstance(inputs, dict):
            inputs = [inputs]

        if not inputs:
            return [{"error": "No inputs provided", "triggers": []}]

        # Validate and format inputs as simpletransformers QA examples
        qa_inputs = []
        valid_indices = []

        for i, item in enumerate(inputs):
            utterance = item.get("utterance", "").strip()
            emotion = item.get("emotion", "")

            if not utterance:
                logger.warning(f"Empty utterance at index {i}")
                continue

            # Format as QA task
            question = self.question_template.format(emotion=emotion)
            qa_inputs.append({
                'context': utterance,
                'qas': [{
                    'id': f'temp_id_{i}',
                    'question': question
                }]
            })
            valid_indices.append(i)

        # Run prediction
        results = []

        if not qa_inputs:
            # All inputs were invalid
            for item in inputs:
                results.append({
                    "utterance": item.get("utterance", ""),
                    "emotion": item.get("emotion", ""),
                    "error": "Missing or empty utterance",
                    "triggers": []
                })
            return results

        try:
            # NOTE(review): assumes predict() returns answers in the same
            # order as qa_inputs (simpletransformers preserves input order);
            # result_idx walks predictions in lockstep with valid entries.
            predictions, _ = self.model.predict(qa_inputs)
            logger.debug(f"Raw predictions: {predictions}")

            # Post-process results
            result_idx = 0
            for i, item in enumerate(inputs):
                utterance = item.get("utterance", "").strip()
                emotion = item.get("emotion", "")

                if i not in valid_indices:
                    # Invalid input
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "error": "Missing or empty utterance",
                        "triggers": []
                    })
                else:
                    # Valid input - process prediction
                    prediction = predictions[result_idx]
                    answer = prediction.get('answer')

                    # Extract and clean spans (answer may be an n-best list
                    # or a single string depending on model args)
                    if isinstance(answer, list) and len(answer) > 0:
                        non_empty_answers = [a for a in answer if a]
                        triggers = self._clean_spans(non_empty_answers, utterance)
                    elif isinstance(answer, str):
                        triggers = self._clean_spans([answer], utterance)
                    else:
                        triggers = []

                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "triggers": triggers
                    })
                    result_idx += 1

            logger.debug(f"Cleaned results: {results}")
            return results

        except Exception as e:
            logger.error(f"Model prediction failed: {e}")
            # Return error for all inputs
            return [{
                "utterance": item.get("utterance", ""),
                "emotion": item.get("emotion", ""),
                "error": str(e),
                "triggers": []
            } for item in inputs]

    def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
        """
        Clean and filter extracted trigger spans.

        This function preserves all the post-processing logic from
        predict_trigger.py (lines 78-153) including stopword filtering,
        length constraints, deduplication, and n-gram fallback.

        Args:
            spans: Raw spans extracted by the model
            target_text: Original utterance text

        Returns:
            List of up to 3 cleaned trigger phrases (original casing,
            recovered from target_text where possible)
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            """Normalize a string: strip, lowercase, remove extra spaces and punctuation."""
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            """Extract phrase from target with original casing; falls back
            to the lowercased phrase when no substring match exists."""
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx+len(phrase_lower)]
            return phrase_lower

        # Stopwords to filter out (single-token candidates only)
        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
            "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
            "being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
            "her", "their", "our", "me", "him", "them", "this", "that", "these",
            "those"
        }

        # Collect candidate spans that are substrings of target and reasonable length
        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue
            s_norm = _norm(s)
            if not s_norm:
                continue
            if target_text and s_norm not in target_lower:
                continue
            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue
            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue
            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm)
            })

        # Prefer longer phrases; remove subsumed/duplicate fragments
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=True)
        kept_norms = []
        for c in list(candidates):
            n = c["norm"]
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)

        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        if not cleaned and spans:
            # Fallback: try to salvage a sub-span that actually exists
            # in the target utterance by scanning n-grams up to 8 words.
            # NOTE(review): this path intentionally bypasses the stopword
            # filter above — a lone stopword span can still be returned here.
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or '').lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i+L]
                        # contiguous n-gram match on token boundaries
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j+L] == phrase:
                                cand = " ".join(phrase)
                                best = cand
                                break
                        if best:
                            break
                    if best:
                        break
                if best:
                    break
            if best:
                return [_extract_from_target(target_text, best)]

        return cleaned[:3]
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e82e93aaf74df78452904601c8ba6502a1e4b90bd9b26ed55ddfe0a279e8fc18
3
+ size 496250232
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers>=4.30.0,<5.0.0
2
+ torch>=2.0.0
3
+ simpletransformers>=0.64.0
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<pad>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "50264": {
37
+ "content": "<mask>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "bos_token": "<s>",
46
+ "clean_up_tokenization_spaces": false,
47
+ "cls_token": "<s>",
48
+ "do_lower_case": false,
49
+ "eos_token": "</s>",
50
+ "errors": "replace",
51
+ "extra_special_tokens": {},
52
+ "mask_token": "<mask>",
53
+ "model_max_length": 512,
54
+ "pad_token": "<pad>",
55
+ "sep_token": "</s>",
56
+ "tokenizer_class": "RobertaTokenizer",
57
+ "unk_token": "<unk>"
58
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff