PyTorch
bart
jsoars committed on
Commit
b31027a
·
verified ·
1 Parent(s): 7c8df7d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +227 -48
handler.py CHANGED
@@ -1,6 +1,9 @@
1
- from typing import Any, Dict, List
 
 
 
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
 
6
  class EndpointHandler:
@@ -19,56 +22,199 @@ class EndpointHandler:
19
  "keywords:",
20
  ]
21
 
22
- def _clean_keywords(self, raw_text: str, source_text: str) -> List[str]:
23
- source_lower = source_text.lower().strip()
24
- raw_parts = [part.strip() for part in raw_text.split(";") if part.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- seen = set()
27
- cleaned: List[str] = []
 
 
 
 
 
28
 
29
- for kw in raw_parts:
30
- kw_clean = " ".join(kw.split()).strip()
31
- kw_lower = kw_clean.lower()
32
 
33
- if not kw_clean:
34
- continue
 
 
35
 
36
- # Remove instruction leakage
37
- if any(kw_lower.startswith(prefix) for prefix in self.bad_prefixes):
38
- continue
39
 
40
- # Remove exact/near-full input echoes
41
- if kw_lower == source_lower:
42
- continue
43
- if len(kw_lower) > 30 and kw_lower in source_lower:
44
- continue
45
- if len(source_lower) > 30 and source_lower in kw_lower:
46
- continue
47
 
48
- # Skip very long outputs that are likely sentence fragments, not keywords
49
- if len(kw_clean.split()) > 6:
50
- continue
 
 
 
 
 
51
 
52
- # Skip obvious clause/sentence-like phrases
53
- sentence_markers = [" and ", " because ", " that ", " which ", " where ", " when "]
54
- if any(marker in kw_lower for marker in sentence_markers) and len(kw_clean.split()) > 4:
55
- continue
56
 
57
- # Trim surrounding punctuation
58
- kw_clean = kw_clean.strip(" ,.;:-")
 
59
 
60
- if not kw_clean:
61
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # Dedupe case-insensitively
64
- normalized = kw_clean.lower()
65
- if normalized in seen:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- seen.add(normalized)
69
- cleaned.append(kw_clean)
 
 
70
 
71
- return cleaned
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
74
  text = data.get("inputs")
@@ -81,11 +227,20 @@ class EndpointHandler:
81
  parameters = data.get("parameters", {})
82
 
83
  max_input_length = int(parameters.get("max_input_length", 1024))
84
- max_new_tokens = int(parameters.get("max_new_tokens", 48))
85
- num_beams = int(parameters.get("num_beams", 4))
 
86
  do_sample = bool(parameters.get("do_sample", False))
87
- temperature = float(parameters.get("temperature", 1.0))
88
- no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 3))
 
 
 
 
 
 
 
 
89
 
90
  encoded = self.tokenizer(
91
  text,
@@ -99,6 +254,7 @@ class EndpointHandler:
99
  **encoded,
100
  "max_new_tokens": max_new_tokens,
101
  "num_beams": num_beams,
 
102
  "do_sample": do_sample,
103
  "no_repeat_ngram_size": no_repeat_ngram_size,
104
  "early_stopping": True,
@@ -106,14 +262,37 @@ class EndpointHandler:
106
 
107
  if do_sample:
108
  generate_kwargs["temperature"] = temperature
 
109
 
110
  with torch.inference_mode():
111
  output_ids = self.model.generate(**generate_kwargs)
112
 
113
- raw_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
114
- keywords = self._clean_keywords(raw_text, text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- return {
117
- "generated_text": raw_text,
118
  "keywords": keywords,
119
- }
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Tuple
2
+ import math
3
+ import re
4
+
5
  import torch
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
 
8
 
9
  class EndpointHandler:
 
22
  "keywords:",
23
  ]
24
 
25
+ self.generic_phrases = {
26
+ "new platform",
27
+ "platform",
28
+ "company",
29
+ "market",
30
+ "markets",
31
+ "system",
32
+ "technology",
33
+ "solution",
34
+ "services",
35
+ "service",
36
+ "product",
37
+ "products",
38
+ "tool",
39
+ "tools",
40
+ }
41
 
42
+ self.stopwords = {
43
+ "a", "an", "the", "and", "or", "of", "for", "to", "in", "on", "with",
44
+ "by", "at", "from", "into", "over", "under", "through", "across",
45
+ "is", "are", "was", "were", "be", "been", "being",
46
+ "this", "that", "these", "those", "it", "its", "their",
47
+ "new", "latest"
48
+ }
49
 
50
+ def _normalize_space(self, text: str) -> str:
51
+ return " ".join(text.split()).strip()
 
52
 
53
+ def _normalize_phrase(self, text: str) -> str:
54
+ text = self._normalize_space(text)
55
+ text = text.strip(" ,.;:-_")
56
+ return text
57
 
58
+ def _phrase_tokens(self, text: str) -> List[str]:
59
+ return re.findall(r"[A-Za-z0-9][A-Za-z0-9\-+/.]*", text.lower())
 
60
 
61
+ def _contains_instruction_leakage(self, phrase_lower: str) -> bool:
62
+ return any(phrase_lower.startswith(prefix) for prefix in self.bad_prefixes)
 
 
 
 
 
63
 
64
+ def _looks_sentence_like(self, phrase: str) -> bool:
65
+ lower = phrase.lower()
66
+ markers = [" and ", " because ", " which ", " where ", " when ", " while ", " after ", " before "]
67
+ if any(m in lower for m in markers) and len(phrase.split()) > 4:
68
+ return True
69
+ if phrase.endswith("."):
70
+ return True
71
+ return False
72
 
73
+ def _is_too_generic(self, phrase: str) -> bool:
74
+ lower = phrase.lower()
75
+ if lower in self.generic_phrases:
76
+ return True
77
 
78
+ tokens = self._phrase_tokens(lower)
79
+ if len(tokens) == 1 and tokens[0] in self.generic_phrases:
80
+ return True
81
 
82
+ # phrases like "new platform" or "new system"
83
+ if len(tokens) == 2 and tokens[0] in {"new", "latest"} and tokens[1] in self.generic_phrases:
84
+ return True
85
+
86
+ return False
87
+
88
+ def _jaccard(self, a: List[str], b: List[str]) -> float:
89
+ sa, sb = set(a), set(b)
90
+ if not sa or not sb:
91
+ return 0.0
92
+ return len(sa & sb) / len(sa | sb)
93
+
94
+ def _text_coverage_score(self, phrase: str, source_text: str) -> float:
95
+ """
96
+ Soft relevance score using literal presence and token overlap.
97
+ Keeps semantically good present phrases near the top.
98
+ """
99
+ phrase_lower = phrase.lower()
100
+ source_lower = source_text.lower()
101
+
102
+ score = 0.0
103
+
104
+ if phrase_lower in source_lower:
105
+ score += 4.0
106
+
107
+ phrase_tokens = self._phrase_tokens(phrase)
108
+ source_tokens = self._phrase_tokens(source_text)
109
+
110
+ if not phrase_tokens:
111
+ return 0.0
112
+
113
+ overlap = len(set(phrase_tokens) & set(source_tokens))
114
+ score += overlap * 1.25
115
+ score += self._jaccard(phrase_tokens, source_tokens) * 2.0
116
+
117
+ # prefer 2–3 word phrases slightly
118
+ wc = len(phrase.split())
119
+ if wc == 2:
120
+ score += 1.0
121
+ elif wc == 3:
122
+ score += 0.75
123
+ elif wc == 1:
124
+ score += 0.25
125
+ elif wc >= 5:
126
+ score -= 1.0
127
+
128
+ # penalize generic lead words
129
+ if phrase_tokens and phrase_tokens[0] in self.stopwords:
130
+ score -= 0.75
131
+
132
+ return score
133
+
134
+ def _parse_candidates(self, generated_texts: List[str], source_text: str, max_keyword_words: int) -> List[str]:
135
+ source_lower = self._normalize_space(source_text.lower())
136
+ candidates: List[str] = []
137
+
138
+ for raw_text in generated_texts:
139
+ parts = [self._normalize_phrase(p) for p in raw_text.split(";")]
140
+ for part in parts:
141
+ if not part:
142
+ continue
143
+
144
+ lower = part.lower()
145
 
146
+ if self._contains_instruction_leakage(lower):
147
+ continue
148
+
149
+ if lower == source_lower:
150
+ continue
151
+
152
+ if len(lower) > 30 and lower in source_lower:
153
+ # likely near-complete echo
154
+ continue
155
+
156
+ if self._looks_sentence_like(part):
157
+ continue
158
+
159
+ wc = len(part.split())
160
+ if wc == 0 or wc > max_keyword_words:
161
+ continue
162
+
163
+ if self._is_too_generic(part):
164
+ continue
165
+
166
+ candidates.append(part)
167
+
168
+ return candidates
169
+
170
+ def _dedupe_and_prune(self, phrases: List[str], source_text: str, top_k: int) -> List[Tuple[str, float]]:
171
+ # First score
172
+ scored: List[Tuple[str, float]] = []
173
+ seen_exact = set()
174
+
175
+ for phrase in phrases:
176
+ norm = phrase.lower()
177
+ if norm in seen_exact:
178
  continue
179
+ seen_exact.add(norm)
180
+
181
+ score = self._text_coverage_score(phrase, source_text)
182
+ if score > 0:
183
+ scored.append((phrase, score))
184
+
185
+ # Sort best first
186
+ scored.sort(key=lambda x: x[1], reverse=True)
187
+
188
+ # Remove subsumed / near-duplicate phrases
189
+ final_scored: List[Tuple[str, float]] = []
190
+ for phrase, score in scored:
191
+ ptoks = self._phrase_tokens(phrase)
192
+ pset = set(ptoks)
193
+
194
+ should_skip = False
195
+ for kept_phrase, kept_score in final_scored:
196
+ ktoks = self._phrase_tokens(kept_phrase)
197
+ kset = set(ktoks)
198
 
199
+ # exact token subset of a better phrase -> drop shorter one
200
+ if pset and pset.issubset(kset):
201
+ should_skip = True
202
+ break
203
 
204
+ # heavy overlap and shorter/weaker -> drop
205
+ jac = self._jaccard(ptoks, ktoks)
206
+ if jac >= 0.6:
207
+ if len(ptoks) <= len(ktoks) and score <= kept_score + 0.5:
208
+ should_skip = True
209
+ break
210
+
211
+ if not should_skip:
212
+ final_scored.append((phrase, round(score, 4)))
213
+
214
+ if len(final_scored) >= top_k:
215
+ break
216
+
217
+ return final_scored
218
 
219
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
220
  text = data.get("inputs")
 
227
  parameters = data.get("parameters", {})
228
 
229
  max_input_length = int(parameters.get("max_input_length", 1024))
230
+ max_new_tokens = int(parameters.get("max_new_tokens", 32))
231
+ num_beams = int(parameters.get("num_beams", 6))
232
+ num_return_sequences = int(parameters.get("num_return_sequences", 4))
233
  do_sample = bool(parameters.get("do_sample", False))
234
+ temperature = float(parameters.get("temperature", 0.9))
235
+ top_p = float(parameters.get("top_p", 0.95))
236
+ no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 2))
237
+ max_keyword_words = int(parameters.get("max_keyword_words", 4))
238
+ top_k_keywords = int(parameters.get("top_k_keywords", 6))
239
+ return_scores = bool(parameters.get("return_scores", False))
240
+
241
+ if not do_sample:
242
+ # beam search requires return_sequences <= beams
243
+ num_return_sequences = min(num_return_sequences, num_beams)
244
 
245
  encoded = self.tokenizer(
246
  text,
 
254
  **encoded,
255
  "max_new_tokens": max_new_tokens,
256
  "num_beams": num_beams,
257
+ "num_return_sequences": num_return_sequences,
258
  "do_sample": do_sample,
259
  "no_repeat_ngram_size": no_repeat_ngram_size,
260
  "early_stopping": True,
 
262
 
263
  if do_sample:
264
  generate_kwargs["temperature"] = temperature
265
+ generate_kwargs["top_p"] = top_p
266
 
267
  with torch.inference_mode():
268
  output_ids = self.model.generate(**generate_kwargs)
269
 
270
+ generated_texts = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
271
+ generated_texts = [self._normalize_space(t) for t in generated_texts if self._normalize_space(t)]
272
+
273
+ candidates = self._parse_candidates(
274
+ generated_texts=generated_texts,
275
+ source_text=text,
276
+ max_keyword_words=max_keyword_words,
277
+ )
278
+
279
+ ranked = self._dedupe_and_prune(
280
+ phrases=candidates,
281
+ source_text=text,
282
+ top_k=top_k_keywords,
283
+ )
284
+
285
+ keywords = [phrase for phrase, _ in ranked]
286
 
287
+ response: Dict[str, Any] = {
288
+ "generated_texts": generated_texts,
289
  "keywords": keywords,
290
+ }
291
+
292
+ if return_scores:
293
+ response["keyword_scores"] = [
294
+ {"keyword": phrase, "score": score}
295
+ for phrase, score in ranked
296
+ ]
297
+
298
+ return response