from typing import Any, Dict, List, Tuple
import math
import re

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    """Keyphrase-extraction request handler wrapping a seq2seq model.

    The model output is split on ``;`` into candidate phrases. Candidates are
    filtered (instruction echoes, sentence-like fragments, overly generic
    terms, phrases longer than ``max_keyword_words``), scored by overlap with
    the input text, de-duplicated, and truncated to ``top_k_keywords``.
    """

    # A token starts with a (Unicode) letter or digit and may continue with
    # letters, digits and the joiners "-", "+", "/", "." so that "covid-19",
    # "c++" and "node.js" stay whole. Underscore is deliberately excluded.
    _TOKEN_RE = re.compile(r"[^\W_](?:[^\W_]|[-+/.])*")

    # Accepted string spellings for boolean request parameters.
    _TRUE_STRINGS = frozenset({"true", "1", "yes", "y", "on"})
    _FALSE_STRINGS = frozenset({"false", "0", "no", "n", "off", ""})

    def __init__(self, path: str = ""):
        """Load tokenizer and model from ``path`` and put the model in eval mode.

        Args:
            path: Location passed to the transformers ``from_pretrained``
                loaders (the endpoint's model directory).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

        # Candidates starting with these are echoes of the prompt, not keywords.
        self.bad_prefixes = [
            "extract keyphrases:",
            "extract keywords:",
            "keyphrases:",
            "keywords:",
        ]
        # Phrases too vague to be useful keywords on their own.
        self.generic_phrases = {
            "new platform", "platform", "company", "market", "markets",
            "system", "technology", "solution", "services", "service",
            "product", "products", "tool", "tools",
        }
        # Function words; a phrase whose first token is one of these is
        # penalised during scoring.
        self.stopwords = {
            "a", "an", "the", "and", "or", "of", "for", "to", "in", "on", "with",
            "by", "at", "from", "into", "over", "under", "through", "across",
            "is", "are", "was", "were", "be", "been", "being", "this", "that",
            "these", "those", "it", "its", "their", "new", "latest",
        }

    # ------------------------------------------------------------------ #
    # Text helpers
    # ------------------------------------------------------------------ #

    def _normalize_space(self, text: str) -> str:
        """Collapse every whitespace run to a single space and trim both ends."""
        return " ".join(text.split())

    def _normalize_phrase(self, text: str) -> str:
        """Normalize whitespace, then trim surrounding spaces and ``,.;:-_``."""
        return self._normalize_space(text).strip(" ,.;:-_")

    def _phrase_tokens(self, text: str) -> List[str]:
        """Lower-case ``text`` and split it into word tokens.

        Trailing ``.``, ``-`` and ``/`` are removed from each token so that a
        word at the end of a sentence ("learning.") matches the bare word
        ("learning"). Internal joiners are kept ("node.js", "c++").
        """
        # rstrip can never empty a token: tokens start with a letter/digit.
        return [tok.rstrip(".-/") for tok in self._TOKEN_RE.findall(text.lower())]

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Interpret a request parameter as a boolean.

        Strings are parsed ("false" -> False) instead of using ``bool()``,
        which would treat every non-empty string as True.

        Raises:
            ValueError: for a string that is not a recognised boolean spelling.
        """
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in cls._TRUE_STRINGS:
                return True
            if lowered in cls._FALSE_STRINGS:
                return False
            raise ValueError(f"cannot interpret {value!r} as a boolean")
        return bool(value)

    # ------------------------------------------------------------------ #
    # Candidate filters
    # ------------------------------------------------------------------ #

    def _contains_instruction_leakage(self, phrase_lower: str) -> bool:
        """Return True if the (lower-cased) phrase starts with a prompt prefix."""
        return phrase_lower.startswith(tuple(self.bad_prefixes))

    def _looks_sentence_like(self, phrase: str) -> bool:
        """Heuristically detect clause/sentence fragments.

        True when the phrase has more than four words and contains a
        conjunction-like marker, or when it ends with a period.
        """
        lower = phrase.lower()
        markers = (" and ", " because ", " which ", " where ",
                   " when ", " while ", " after ", " before ")
        if any(m in lower for m in markers) and len(phrase.split()) > 4:
            return True
        return phrase.endswith(".")

    def _is_too_generic(self, phrase: str) -> bool:
        """Return True for phrases that are generic on their own.

        This covers listed generic phrases, single generic tokens, and
        "new"/"latest" followed by a generic token (e.g. "new system").
        """
        lower = phrase.lower()
        if lower in self.generic_phrases:
            return True
        tokens = self._phrase_tokens(lower)
        if len(tokens) == 1:
            return tokens[0] in self.generic_phrases
        # Phrases like "new platform" or "latest system".
        return (
            len(tokens) == 2
            and tokens[0] in {"new", "latest"}
            and tokens[1] in self.generic_phrases
        )

    # ------------------------------------------------------------------ #
    # Scoring
    # ------------------------------------------------------------------ #

    def _jaccard(self, a: List[str], b: List[str]) -> float:
        """Jaccard similarity of the token sets of ``a`` and ``b`` (0.0 if either is empty)."""
        sa, sb = set(a), set(b)
        if not sa or not sb:
            return 0.0
        return len(sa & sb) / len(sa | sb)

    def _coverage_score(self, phrase: str, source_norm_lower: str,
                        source_tokens: List[str]) -> float:
        """Score ``phrase`` against an already-prepared source text.

        Args:
            phrase: Candidate keyphrase.
            source_norm_lower: Source text, whitespace-normalized and lower-cased.
            source_tokens: ``_phrase_tokens`` of the source text.

        Returns:
            The relevance score; 0.0 when the phrase yields no tokens.
        """
        phrase_tokens = self._phrase_tokens(phrase)
        if not phrase_tokens:
            return 0.0

        score = 0.0
        # Literal presence. Both sides are whitespace-normalized so that line
        # breaks or double spaces in the source do not hide a match.
        if self._normalize_space(phrase).lower() in source_norm_lower:
            score += 4.0

        overlap = len(set(phrase_tokens) & set(source_tokens))
        score += overlap * 1.25
        score += self._jaccard(phrase_tokens, source_tokens) * 2.0

        # Slightly prefer 2-3 word phrases; penalise long ones.
        wc = len(phrase.split())
        if wc == 2:
            score += 1.0
        elif wc == 3:
            score += 0.75
        elif wc == 1:
            score += 0.25
        elif wc >= 5:
            score -= 1.0

        # Penalise phrases that lead with a function word.
        if phrase_tokens[0] in self.stopwords:
            score -= 0.75

        return score

    def _text_coverage_score(self, phrase: str, source_text: str) -> float:
        """
        Soft relevance score using literal presence and token overlap.
        Keeps semantically good present phrases near the top.
        """
        return self._coverage_score(
            phrase,
            self._normalize_space(source_text).lower(),
            self._phrase_tokens(source_text),
        )

    # ------------------------------------------------------------------ #
    # Candidate pipeline
    # ------------------------------------------------------------------ #

    def _parse_candidates(self, generated_texts: List[str], source_text: str,
                          max_keyword_words: int) -> List[str]:
        """Split generations on ``;`` and keep the parts that pass every filter.

        Returns:
            Surviving phrases in generation order. Duplicates are possible;
            they are removed later in ``_dedupe_and_prune``.
        """
        source_lower = self._normalize_space(source_text.lower())
        candidates: List[str] = []

        for raw_text in generated_texts:
            for part in (self._normalize_phrase(p) for p in raw_text.split(";")):
                if not part:
                    continue
                lower = part.lower()
                if self._contains_instruction_leakage(lower):
                    continue
                if lower == source_lower:
                    continue
                # A long substring of the source is likely a near-complete echo.
                if len(lower) > 30 and lower in source_lower:
                    continue
                if self._looks_sentence_like(part):
                    continue
                wc = len(part.split())
                if wc == 0 or wc > max_keyword_words:
                    continue
                if self._is_too_generic(part):
                    continue
                candidates.append(part)

        return candidates

    def _dedupe_and_prune(self, phrases: List[str], source_text: str,
                          top_k: int) -> List[Tuple[str, float]]:
        """Score, rank and de-duplicate phrases, keeping at most ``top_k``.

        A phrase is dropped when its tokens are a subset of a better-ranked
        kept phrase, or when it overlaps one heavily (Jaccard >= 0.6), is no
        longer, and does not score more than 0.5 above it.

        Returns:
            ``(phrase, score)`` pairs, best first, with scores rounded to
            4 decimals. Empty when ``top_k <= 0``.
        """
        if top_k <= 0:
            return []

        # Prepare the source once instead of once per candidate.
        source_norm_lower = self._normalize_space(source_text).lower()
        source_tokens = self._phrase_tokens(source_text)

        scored: List[Tuple[str, float]] = []
        seen_exact = set()
        for phrase in phrases:
            key = phrase.lower()
            if key in seen_exact:
                continue
            seen_exact.add(key)
            score = self._coverage_score(phrase, source_norm_lower, source_tokens)
            if score > 0:
                scored.append((phrase, score))

        # Stable sort: ties keep generation order.
        scored.sort(key=lambda item: item[1], reverse=True)

        final_scored: List[Tuple[str, float]] = []
        # Token list, token set and rounded score of each kept phrase.
        kept: List[Tuple[List[str], set, float]] = []
        for phrase, score in scored:
            ptoks = self._phrase_tokens(phrase)
            pset = set(ptoks)

            redundant = False
            for ktoks, kset, kept_score in kept:
                # Exact token subset of a better phrase -> drop the shorter one.
                if pset and pset.issubset(kset):
                    redundant = True
                    break
                # Heavy overlap, no longer and not clearly stronger -> drop.
                if (self._jaccard(ptoks, ktoks) >= 0.6
                        and len(ptoks) <= len(ktoks)
                        and score <= kept_score + 0.5):
                    redundant = True
                    break
            if redundant:
                continue

            rounded = round(score, 4)
            final_scored.append((phrase, rounded))
            kept.append((ptoks, pset, rounded))
            if len(final_scored) >= top_k:
                break

        return final_scored

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    @staticmethod
    def _build_response(generated_texts: List[str],
                        ranked: List[Tuple[str, float]],
                        return_scores: bool) -> Dict[str, Any]:
        """Assemble the JSON-serialisable response payload."""
        response: Dict[str, Any] = {
            "generated_texts": generated_texts,
            "keywords": [phrase for phrase, _ in ranked],
        }
        if return_scores:
            response["keyword_scores"] = [
                {"keyword": phrase, "score": score} for phrase, score in ranked
            ]
        return response

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate and rank keyphrases for ``data["inputs"]``.

        Args:
            data: Request payload. ``inputs`` (str) is required.
                ``parameters`` (dict, optional) may override:
                max_input_length, max_new_tokens, num_beams,
                num_return_sequences, do_sample, temperature, top_p,
                no_repeat_ngram_size, max_keyword_words, top_k_keywords,
                return_scores.

        Returns:
            ``{"generated_texts": [...], "keywords": [...]}``, plus
            ``"keyword_scores"`` when ``return_scores`` is true. Invalid
            requests return ``{"error": <message>}`` instead of raising.
        """
        text = data.get("inputs")
        if text is None:
            return {"error": "Missing required field: inputs"}
        if not isinstance(text, str):
            return {"error": "The 'inputs' field must be a string"}

        # An explicit "parameters": null is treated like an absent field.
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            return {"error": "The 'parameters' field must be an object"}

        try:
            # Counts that must be positive for tokenization/generation.
            max_input_length = max(1, int(parameters.get("max_input_length", 1024)))
            max_new_tokens = max(1, int(parameters.get("max_new_tokens", 32)))
            num_beams = max(1, int(parameters.get("num_beams", 6)))
            num_return_sequences = max(1, int(parameters.get("num_return_sequences", 4)))
            do_sample = self._as_bool(parameters.get("do_sample", False))
            temperature = float(parameters.get("temperature", 0.9))
            top_p = float(parameters.get("top_p", 0.95))
            no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 2))
            max_keyword_words = int(parameters.get("max_keyword_words", 4))
            top_k_keywords = int(parameters.get("top_k_keywords", 6))
            return_scores = self._as_bool(parameters.get("return_scores", False))
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid parameter value: {exc}"}

        # Nothing to extract from blank input; avoid running the model, whose
        # output would otherwise be returned unanchored to any source text.
        if not text.strip():
            return self._build_response([], [], return_scores)

        # transformers rejects num_return_sequences > num_beams for beam
        # search and beam sampling, and != 1 for greedy decoding. Pure
        # sampling (num_beams == 1, do_sample) may return any number.
        if num_beams > 1 or not do_sample:
            num_return_sequences = min(num_return_sequences, num_beams)

        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        generate_kwargs: Dict[str, Any] = {
            **encoded,
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "num_return_sequences": num_return_sequences,
            "do_sample": do_sample,
            "no_repeat_ngram_size": no_repeat_ngram_size,
        }
        # early_stopping only applies to beam-based decoding.
        if num_beams > 1:
            generate_kwargs["early_stopping"] = True
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            output_ids = self.model.generate(**generate_kwargs)

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        generated_texts = [t for t in (self._normalize_space(d) for d in decoded) if t]

        candidates = self._parse_candidates(
            generated_texts=generated_texts,
            source_text=text,
            max_keyword_words=max_keyword_words,
        )
        ranked = self._dedupe_and_prune(
            phrases=candidates,
            source_text=text,
            top_k=top_k_keywords,
        )
        return self._build_response(generated_texts, ranked, return_scores)