# PyTorch
# bart
# KeyBART / handler.py
# jsoars's picture
# Update handler.py
# b31027a verified
from typing import Any, Dict, List, Tuple
import math
import re
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class EndpointHandler:
    """Callable handler that generates text with a seq2seq model and ranks keyphrases.

    The handler loads a tokenizer and a seq2seq model from ``path``, generates
    one or more ``;``-separated keyphrase strings for an input text, then
    filters, scores, de-duplicates and truncates the candidate phrases.

    NOTE(review): the ``inputs`` / ``parameters`` request shape suggests this is
    used as a Hugging Face Inference Endpoints custom handler -- confirm
    against the deployment.
    """

    # Token pattern shared by all scoring / de-duplication helpers. Compiled
    # once at class level so it is not re-resolved on every call.
    _TOKEN_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9\-+/.]*")

    def __init__(self, path: str = ""):
        """Load tokenizer and model from ``path`` and prepare filter word lists.

        Args:
            path: Location passed to ``from_pretrained`` for both the tokenizer
                and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

        # Prompt-like prefixes; a generated phrase starting with one of these
        # is discarded as instruction leakage.
        self.bad_prefixes = [
            "extract keyphrases:",
            "extract keywords:",
            "keyphrases:",
            "keywords:",
        ]

        # Phrases considered too generic to be useful keywords.
        self.generic_phrases = {
            "new platform",
            "platform",
            "company",
            "market",
            "markets",
            "system",
            "technology",
            "solution",
            "services",
            "service",
            "product",
            "products",
            "tool",
            "tools",
        }

        # A phrase whose first token is one of these gets a score penalty.
        self.stopwords = {
            "a", "an", "the", "and", "or", "of", "for", "to", "in", "on", "with",
            "by", "at", "from", "into", "over", "under", "through", "across",
            "is", "are", "was", "were", "be", "been", "being",
            "this", "that", "these", "those", "it", "its", "their",
            "new", "latest",
        }

    def _normalize_space(self, text: str) -> str:
        """Collapse every run of whitespace in ``text`` to a single space."""
        return " ".join(text.split()).strip()

    def _normalize_phrase(self, text: str) -> str:
        """Normalize whitespace and strip surrounding separator punctuation."""
        text = self._normalize_space(text)
        return text.strip(" ,.;:-_")

    def _phrase_tokens(self, text: str) -> List[str]:
        """Return lowercase word tokens of ``text``.

        A token starts with an ASCII letter or digit and may continue with
        letters, digits, ``-``, ``+``, ``/`` or ``.`` (so ``node.js`` and
        ``c++`` stay intact). Trailing ``-``, ``/`` and ``.`` are removed so
        that a word at the end of a sentence (``"learning."``) compares equal
        to the same word inside a phrase (``"learning"``).
        """
        # Every match starts with an alphanumeric character, so rstrip can
        # never produce an empty token.
        return [tok.rstrip("-/.") for tok in self._TOKEN_RE.findall(text.lower())]

    def _contains_instruction_leakage(self, phrase_lower: str) -> bool:
        """Return True if the (already lowercased) phrase starts with a bad prefix."""
        return phrase_lower.startswith(tuple(self.bad_prefixes))

    def _looks_sentence_like(self, phrase: str) -> bool:
        """Return True if ``phrase`` reads like a clause rather than a keyphrase.

        A phrase is sentence-like when it has more than four words and contains
        a connective such as " and " or " because ", or when it ends with a
        period.
        """
        lower = phrase.lower()
        markers = [" and ", " because ", " which ", " where ", " when ", " while ", " after ", " before "]
        if any(m in lower for m in markers) and len(phrase.split()) > 4:
            return True
        return phrase.endswith(".")

    def _is_too_generic(self, phrase: str) -> bool:
        """Return True if ``phrase`` is a generic term with little keyword value."""
        lower = phrase.lower()
        if lower in self.generic_phrases:
            return True
        tokens = self._phrase_tokens(lower)
        if len(tokens) == 1 and tokens[0] in self.generic_phrases:
            return True
        # phrases like "new platform" or "new system"
        if len(tokens) == 2 and tokens[0] in {"new", "latest"} and tokens[1] in self.generic_phrases:
            return True
        return False

    def _jaccard(self, a: List[str], b: List[str]) -> float:
        """Return the Jaccard similarity of two token lists (0.0 if either is empty)."""
        sa, sb = set(a), set(b)
        if not sa or not sb:
            return 0.0
        return len(sa & sb) / len(sa | sb)

    def _text_coverage_score(self, phrase: str, source_text: str) -> float:
        """
        Soft relevance score using literal presence and token overlap.
        Keeps semantically good present phrases near the top.

        Returns 0.0 when the phrase yields no tokens. Otherwise the score is
        built from: +4.0 if the lowercased phrase occurs verbatim in the
        source, +1.25 per distinct shared token, 2x the token Jaccard
        similarity, a word-count adjustment, and a penalty for a stopword lead.
        """
        phrase_lower = phrase.lower()
        source_lower = source_text.lower()
        score = 0.0
        if phrase_lower in source_lower:
            score += 4.0

        phrase_tokens = self._phrase_tokens(phrase)
        source_tokens = self._phrase_tokens(source_text)
        if not phrase_tokens:
            return 0.0

        overlap = len(set(phrase_tokens) & set(source_tokens))
        score += overlap * 1.25
        score += self._jaccard(phrase_tokens, source_tokens) * 2.0

        # prefer 2-3 word phrases slightly (4-word phrases are neutral)
        wc = len(phrase.split())
        if wc == 2:
            score += 1.0
        elif wc == 3:
            score += 0.75
        elif wc == 1:
            score += 0.25
        elif wc >= 5:
            score -= 1.0

        # penalize generic lead words
        if phrase_tokens[0] in self.stopwords:
            score -= 0.75
        return score

    def _parse_candidates(self, generated_texts: List[str], source_text: str, max_keyword_words: int) -> List[str]:
        """Split generated strings on ``;`` and keep plausible keyphrases.

        Args:
            generated_texts: Decoded model outputs.
            source_text: The original input text, used to detect echoes.
            max_keyword_words: Maximum number of whitespace-separated words a
                candidate may have.

        Returns:
            Candidate phrases in generation order (duplicates are not removed
            here).
        """
        source_lower = self._normalize_space(source_text.lower())
        candidates: List[str] = []
        for raw_text in generated_texts:
            for piece in raw_text.split(";"):
                part = self._normalize_phrase(piece)
                if not part:
                    continue
                lower = part.lower()
                if self._contains_instruction_leakage(lower):
                    continue
                if lower == source_lower:
                    continue
                if len(lower) > 30 and lower in source_lower:
                    # likely near-complete echo
                    continue
                if self._looks_sentence_like(part):
                    continue
                wc = len(part.split())
                if wc == 0 or wc > max_keyword_words:
                    continue
                if self._is_too_generic(part):
                    continue
                candidates.append(part)
        return candidates

    def _dedupe_and_prune(self, phrases: List[str], source_text: str, top_k: int) -> List[Tuple[str, float]]:
        """Score, de-duplicate and truncate candidate phrases.

        Args:
            phrases: Candidate phrases (may contain case-insensitive duplicates).
            source_text: Text the phrases are scored against.
            top_k: Maximum number of phrases to return; values <= 0 yield an
                empty list.

        Returns:
            ``(phrase, score)`` pairs, best score first, with scores rounded to
            four decimals.
        """
        if top_k <= 0:
            # The selection loop below appends before checking the limit, so a
            # non-positive limit must be handled up front.
            return []

        # First score
        scored: List[Tuple[str, float]] = []
        seen_exact = set()
        for phrase in phrases:
            norm = phrase.lower()
            if norm in seen_exact:
                continue
            seen_exact.add(norm)
            score = self._text_coverage_score(phrase, source_text)
            if score > 0:
                scored.append((phrase, score))

        # Sort best first (stable: ties keep generation order)
        scored.sort(key=lambda x: x[1], reverse=True)

        # Remove subsumed / near-duplicate phrases
        final_scored: List[Tuple[str, float]] = []
        for phrase, score in scored:
            ptoks = self._phrase_tokens(phrase)
            pset = set(ptoks)
            should_skip = False
            for kept_phrase, kept_score in final_scored:
                ktoks = self._phrase_tokens(kept_phrase)
                # exact token subset of a better phrase -> drop shorter one
                if pset and pset.issubset(set(ktoks)):
                    should_skip = True
                    break
                # heavy overlap and shorter/weaker -> drop
                if self._jaccard(ptoks, ktoks) >= 0.6:
                    if len(ptoks) <= len(ktoks) and score <= kept_score + 0.5:
                        should_skip = True
                        break
            if not should_skip:
                final_scored.append((phrase, round(score, 4)))
                if len(final_scored) >= top_k:
                    break
        return final_scored

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate keyphrases for ``data["inputs"]``.

        Args:
            data: Request payload. ``inputs`` (str) is required. ``parameters``
                is an optional mapping of generation and post-processing
                options: ``max_input_length``, ``max_new_tokens``,
                ``num_beams``, ``num_return_sequences``, ``do_sample``,
                ``temperature``, ``top_p``, ``no_repeat_ngram_size``,
                ``max_keyword_words``, ``top_k_keywords``, ``return_scores``.

        Returns:
            ``{"generated_texts": [...], "keywords": [...]}``, plus
            ``"keyword_scores"`` when ``return_scores`` is truthy. On invalid
            input an ``{"error": message}`` dict is returned instead.
        """
        text = data.get("inputs")
        if text is None:
            return {"error": "Missing required field: inputs"}
        if not isinstance(text, str):
            return {"error": "The 'inputs' field must be a string"}

        # An explicit ``"parameters": null`` must behave like a missing key.
        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}
        if not isinstance(parameters, dict):
            return {"error": "The 'parameters' field must be an object"}

        # Report malformed values in the same style as the input checks above
        # instead of letting int()/float() raise out of the handler.
        try:
            max_input_length = int(parameters.get("max_input_length", 1024))
            max_new_tokens = int(parameters.get("max_new_tokens", 32))
            num_beams = int(parameters.get("num_beams", 6))
            num_return_sequences = int(parameters.get("num_return_sequences", 4))
            do_sample = bool(parameters.get("do_sample", False))
            temperature = float(parameters.get("temperature", 0.9))
            top_p = float(parameters.get("top_p", 0.95))
            no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 2))
            max_keyword_words = int(parameters.get("max_keyword_words", 4))
            top_k_keywords = int(parameters.get("top_k_keywords", 6))
            return_scores = bool(parameters.get("return_scores", False))
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid parameter value: {exc}"}

        # Values below 1 are not meaningful for generation; clamp them.
        num_beams = max(1, num_beams)
        num_return_sequences = max(1, num_return_sequences)
        if not do_sample:
            # beam search requires return_sequences <= beams
            num_return_sequences = min(num_return_sequences, num_beams)

        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        generate_kwargs = {
            **encoded,
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "num_return_sequences": num_return_sequences,
            "do_sample": do_sample,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "early_stopping": True,
        }
        if do_sample:
            # Sampling knobs are only passed when sampling is enabled.
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            output_ids = self.model.generate(**generate_kwargs)

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # Normalize once per text and drop outputs that are empty afterwards.
        generated_texts = [t for t in map(self._normalize_space, decoded) if t]

        candidates = self._parse_candidates(
            generated_texts=generated_texts,
            source_text=text,
            max_keyword_words=max_keyword_words,
        )
        ranked = self._dedupe_and_prune(
            phrases=candidates,
            source_text=text,
            top_k=top_k_keywords,
        )

        response: Dict[str, Any] = {
            "generated_texts": generated_texts,
            "keywords": [phrase for phrase, _ in ranked],
        }
        if return_scores:
            response["keyword_scores"] = [
                {"keyword": phrase, "score": score}
                for phrase, score in ranked
            ]
        return response