PyTorch
bart
jsoars commited on
Commit
7c8df7d
·
verified ·
1 Parent(s): e8d0baa

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +84 -21
handler.py CHANGED
@@ -12,45 +12,108 @@ class EndpointHandler:
12
  self.model.to(self.device)
13
  self.model.eval()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
16
  text = data.get("inputs")
17
  if text is None:
18
  return {"error": "Missing required field: inputs"}
19
 
 
 
 
20
  parameters = data.get("parameters", {})
21
 
 
 
 
 
 
 
 
22
  encoded = self.tokenizer(
23
  text,
24
  return_tensors="pt",
25
  truncation=True,
26
- max_length=int(parameters.get("max_input_length", 1024)),
27
  )
28
  encoded = {k: v.to(self.device) for k, v in encoded.items()}
29
 
30
- with torch.inference_mode():
31
- output_ids = self.model.generate(
32
- **encoded,
33
- max_new_tokens=int(parameters.get("max_new_tokens", 48)),
34
- num_beams=int(parameters.get("num_beams", 4)),
35
- do_sample=bool(parameters.get("do_sample", False)),
36
- temperature=float(parameters.get("temperature", 1.0)),
37
- no_repeat_ngram_size=int(parameters.get("no_repeat_ngram_size", 3)),
38
- early_stopping=True,
39
- )
40
 
41
- raw_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
 
42
 
43
- keywords = [x.strip() for x in raw_text.split(";") if x.strip()]
 
44
 
45
- seen = set()
46
- deduped: List[str] = []
47
- for kw in keywords:
48
- k = kw.lower()
49
- if k not in seen:
50
- seen.add(k)
51
- deduped.append(kw)
52
 
53
  return {
54
  "generated_text": raw_text,
55
- "keywords": deduped,
56
  }
 
12
  self.model.to(self.device)
13
  self.model.eval()
14
 
15
+ self.bad_prefixes = [
16
+ "extract keyphrases:",
17
+ "extract keywords:",
18
+ "keyphrases:",
19
+ "keywords:",
20
+ ]
21
+
22
+ def _clean_keywords(self, raw_text: str, source_text: str) -> List[str]:
23
+ source_lower = source_text.lower().strip()
24
+ raw_parts = [part.strip() for part in raw_text.split(";") if part.strip()]
25
+
26
+ seen = set()
27
+ cleaned: List[str] = []
28
+
29
+ for kw in raw_parts:
30
+ kw_clean = " ".join(kw.split()).strip()
31
+ kw_lower = kw_clean.lower()
32
+
33
+ if not kw_clean:
34
+ continue
35
+
36
+ # Remove instruction leakage
37
+ if any(kw_lower.startswith(prefix) for prefix in self.bad_prefixes):
38
+ continue
39
+
40
+ # Remove exact/near-full input echoes
41
+ if kw_lower == source_lower:
42
+ continue
43
+ if len(kw_lower) > 30 and kw_lower in source_lower:
44
+ continue
45
+ if len(source_lower) > 30 and source_lower in kw_lower:
46
+ continue
47
+
48
+ # Skip very long outputs that are likely sentence fragments, not keywords
49
+ if len(kw_clean.split()) > 6:
50
+ continue
51
+
52
+ # Skip obvious clause/sentence-like phrases
53
+ sentence_markers = [" and ", " because ", " that ", " which ", " where ", " when "]
54
+ if any(marker in kw_lower for marker in sentence_markers) and len(kw_clean.split()) > 4:
55
+ continue
56
+
57
+ # Trim surrounding punctuation
58
+ kw_clean = kw_clean.strip(" ,.;:-")
59
+
60
+ if not kw_clean:
61
+ continue
62
+
63
+ # Dedupe case-insensitively
64
+ normalized = kw_clean.lower()
65
+ if normalized in seen:
66
+ continue
67
+
68
+ seen.add(normalized)
69
+ cleaned.append(kw_clean)
70
+
71
+ return cleaned
72
+
73
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
74
  text = data.get("inputs")
75
  if text is None:
76
  return {"error": "Missing required field: inputs"}
77
 
78
+ if not isinstance(text, str):
79
+ return {"error": "The 'inputs' field must be a string"}
80
+
81
  parameters = data.get("parameters", {})
82
 
83
+ max_input_length = int(parameters.get("max_input_length", 1024))
84
+ max_new_tokens = int(parameters.get("max_new_tokens", 48))
85
+ num_beams = int(parameters.get("num_beams", 4))
86
+ do_sample = bool(parameters.get("do_sample", False))
87
+ temperature = float(parameters.get("temperature", 1.0))
88
+ no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 3))
89
+
90
  encoded = self.tokenizer(
91
  text,
92
  return_tensors="pt",
93
  truncation=True,
94
+ max_length=max_input_length,
95
  )
96
  encoded = {k: v.to(self.device) for k, v in encoded.items()}
97
 
98
+ generate_kwargs = {
99
+ **encoded,
100
+ "max_new_tokens": max_new_tokens,
101
+ "num_beams": num_beams,
102
+ "do_sample": do_sample,
103
+ "no_repeat_ngram_size": no_repeat_ngram_size,
104
+ "early_stopping": True,
105
+ }
 
 
106
 
107
+ if do_sample:
108
+ generate_kwargs["temperature"] = temperature
109
 
110
+ with torch.inference_mode():
111
+ output_ids = self.model.generate(**generate_kwargs)
112
 
113
+ raw_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
114
+ keywords = self._clean_keywords(raw_text, text)
 
 
 
 
 
115
 
116
  return {
117
  "generated_text": raw_text,
118
+ "keywords": keywords,
119
  }