| from typing import Any, Dict, List, Tuple |
| import math |
| import re |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
class EndpointHandler:
    """Extract keyphrases from text with a seq2seq model plus rule-based post-processing.

    Per request: generate one or more sequences from the input text, split each on
    ``";"`` into candidate phrases, drop unusable candidates (prompt-instruction
    leakage, echoes of the source, sentence-like or overly generic phrases), score
    the rest against the source text, remove near-duplicates, and return the top
    keywords.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape and the
    ``{"inputs": ..., "parameters": ...}`` payload look like a Hugging Face
    Inference Endpoints custom handler -- confirm against the deployment.
    """

    # Tokens start with an ASCII letter/digit and may continue with letters, digits
    # and the joiners "-", "+", "/", "." (so "c++", "node.js", "gpt-4" stay whole).
    _TOKEN_PATTERN = re.compile(r"[A-Za-z0-9][A-Za-z0-9\-+/.]*")

    # Trailing joiners that belong to the surrounding sentence rather than the word:
    # "computers." at the end of a sentence must match the phrase token "computers".
    _TOKEN_TRAILING_PUNCT = ".-/"

    # String spellings treated as False for boolean request parameters
    # (plain ``bool("false")`` would be True).
    _FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off", "none", "null"})

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` and prepare the model for inference.

        Args:
            path: Model location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

        # Lower-cased prefixes that indicate the model echoed its instruction;
        # candidates starting with any of them are discarded.
        self.bad_prefixes = [
            "extract keyphrases:",
            "extract keywords:",
            "keyphrases:",
            "keywords:",
        ]

        # Lower-cased phrases considered too vague to be useful keywords.
        self.generic_phrases = {
            "new platform",
            "platform",
            "company",
            "market",
            "markets",
            "system",
            "technology",
            "solution",
            "services",
            "service",
            "product",
            "products",
            "tool",
            "tools",
        }

        # Function words; a candidate whose first token is one of these is penalized.
        self.stopwords = {
            "a", "an", "the", "and", "or", "of", "for", "to", "in", "on", "with",
            "by", "at", "from", "into", "over", "under", "through", "across",
            "is", "are", "was", "were", "be", "been", "being",
            "this", "that", "these", "those", "it", "its", "their",
            "new", "latest"
        }

    def _normalize_space(self, text: str) -> str:
        """Collapse every run of whitespace to a single space and trim both ends."""
        return " ".join(text.split()).strip()

    def _normalize_phrase(self, text: str) -> str:
        """Whitespace-normalize ``text`` and strip leading/trailing `` ,.;:-_`` characters."""
        return self._normalize_space(text).strip(" ,.;:-_")

    def _phrase_tokens(self, text: str) -> List[str]:
        """Return the lower-cased word tokens of ``text``.

        Trailing ``.``/``-``/``/`` are removed from each token so that sentence
        punctuation in the source does not prevent matches against phrase tokens.
        Characters outside the ASCII token alphabet act as separators.
        """
        return [
            token.rstrip(self._TOKEN_TRAILING_PUNCT)
            for token in self._TOKEN_PATTERN.findall(text.lower())
        ]

    def _contains_instruction_leakage(self, phrase_lower: str) -> bool:
        """Return True if the (already lower-cased) phrase starts with an instruction prefix."""
        return any(phrase_lower.startswith(prefix) for prefix in self.bad_prefixes)

    def _looks_sentence_like(self, phrase: str) -> bool:
        """Heuristically detect clause-like output that is not a keyphrase.

        True when the phrase has more than four words and contains a clause marker
        (" and ", " because ", ...), or when it ends with a period. Candidates
        coming from ``_parse_candidates`` are already stripped of trailing periods,
        so for them only the marker rule applies.
        """
        lower = phrase.lower()
        markers = [" and ", " because ", " which ", " where ", " when ", " while ", " after ", " before "]
        if any(m in lower for m in markers) and len(phrase.split()) > 4:
            return True
        if phrase.endswith("."):
            return True
        return False

    def _is_too_generic(self, phrase: str) -> bool:
        """Return True for phrases that are generic on their own.

        This covers exact entries of ``self.generic_phrases``, single generic tokens,
        and "new"/"latest" followed by a generic token.
        """
        lower = phrase.lower()
        if lower in self.generic_phrases:
            return True

        tokens = self._phrase_tokens(lower)
        if len(tokens) == 1 and tokens[0] in self.generic_phrases:
            return True

        # "new platform", "latest tools", ...: a filler adjective does not make it specific.
        if len(tokens) == 2 and tokens[0] in {"new", "latest"} and tokens[1] in self.generic_phrases:
            return True

        return False

    def _jaccard(self, a: List[str], b: List[str]) -> float:
        """Jaccard similarity of the token sets of ``a`` and ``b``; 0.0 if either is empty."""
        sa, sb = set(a), set(b)
        if not sa or not sb:
            return 0.0
        return len(sa & sb) / len(sa | sb)

    @staticmethod
    def _contains_token_sequence(needle: List[str], haystack: List[str]) -> bool:
        """Return True if ``needle`` occurs as a contiguous run inside ``haystack``."""
        size = len(needle)
        if size == 0 or size > len(haystack):
            return False
        return any(
            haystack[start:start + size] == needle
            for start in range(len(haystack) - size + 1)
        )

    def _text_coverage_score(self, phrase: str, source_text: str) -> float:
        """Soft relevance score of ``phrase`` against ``source_text`` (higher is better).

        Components:
          * +4.0 if the phrase's tokens appear contiguously in the source. The match
            is whole-token, so "art" is not credited for appearing inside "smart",
            and it is whitespace-insensitive.
          * +1.25 per distinct phrase token that also occurs in the source.
          * +2.0 * Jaccard(phrase tokens, source tokens).
          * A length adjustment: +1.0 for 2 words, +0.75 for 3, +0.25 for 1,
            -1.0 for 5 or more.
          * -0.75 if the first token is a stopword.

        Returns:
            The summed score, or 0.0 when the phrase yields no tokens.
        """
        phrase_tokens = self._phrase_tokens(phrase)
        if not phrase_tokens:
            return 0.0
        source_tokens = self._phrase_tokens(source_text)

        score = 0.0
        if self._contains_token_sequence(phrase_tokens, source_tokens):
            score += 4.0

        overlap = len(set(phrase_tokens) & set(source_tokens))
        score += overlap * 1.25
        score += self._jaccard(phrase_tokens, source_tokens) * 2.0

        # Prefer compact 2-3 word phrases; long ones tend to be clause fragments.
        word_count = len(phrase.split())
        if word_count == 2:
            score += 1.0
        elif word_count == 3:
            score += 0.75
        elif word_count == 1:
            score += 0.25
        elif word_count >= 5:
            score -= 1.0

        if phrase_tokens[0] in self.stopwords:
            score -= 0.75

        return score

    def _parse_candidates(self, generated_texts: List[str], source_text: str, max_keyword_words: int) -> List[str]:
        """Split generated texts on ";" and keep only plausible keyphrase candidates.

        A segment is dropped if it is empty, or if it starts with an instruction
        prefix. It is also dropped if it equals the whole source text (both sides
        normalized the same way), or if it is a verbatim source excerpt longer than
        30 characters. Finally it is dropped if it looks sentence-like, has more
        than ``max_keyword_words`` words, or is too generic.

        Returns:
            Surviving candidates in generation order. Duplicates are not removed here.
        """
        # Normalize the source exactly like candidates so a verbatim echo is detected.
        source_lower = self._normalize_phrase(source_text).lower()
        candidates: List[str] = []

        for raw_text in generated_texts:
            for part in (self._normalize_phrase(p) for p in raw_text.split(";")):
                if not part:
                    continue

                lower = part.lower()
                if self._contains_instruction_leakage(lower):
                    continue
                if lower == source_lower:
                    continue
                # Long literal copies of the source are excerpts, not keyphrases.
                if len(lower) > 30 and lower in source_lower:
                    continue
                if self._looks_sentence_like(part):
                    continue

                word_count = len(part.split())
                if word_count == 0 or word_count > max_keyword_words:
                    continue
                if self._is_too_generic(part):
                    continue

                candidates.append(part)

        return candidates

    def _is_redundant(self, tokens: List[str], score: float, kept: List[Tuple[List[str], float]]) -> bool:
        """Return True if a phrase with ``tokens``/``score`` duplicates an already kept phrase.

        A phrase is redundant when its token set is a subset of a kept phrase's
        tokens. It is also redundant when it is highly similar (Jaccard >= 0.6),
        no longer than the kept phrase, and scores at most 0.5 above it.
        """
        token_set = set(tokens)
        for kept_tokens, kept_score in kept:
            if token_set and token_set.issubset(kept_tokens):
                return True
            if (
                self._jaccard(tokens, kept_tokens) >= 0.6
                and len(tokens) <= len(kept_tokens)
                and score <= kept_score + 0.5
            ):
                return True
        return False

    def _dedupe_and_prune(self, phrases: List[str], source_text: str, top_k: int) -> List[Tuple[str, float]]:
        """Score, de-duplicate and truncate candidate phrases.

        Args:
            phrases: Candidate phrases in generation order.
            source_text: The request input used for scoring.
            top_k: Maximum number of phrases to return; values <= 0 return nothing.

        Returns:
            Up to ``top_k`` ``(phrase, score)`` pairs, best first, with scores
            rounded to 4 decimals. Case-insensitive duplicates, phrases scoring
            <= 0, and near-duplicates of better-ranked phrases are removed.
        """
        if top_k <= 0:
            return []

        scored: List[Tuple[str, float]] = []
        seen_exact = set()
        for phrase in phrases:
            key = phrase.lower()
            if key in seen_exact:
                continue
            seen_exact.add(key)

            score = self._text_coverage_score(phrase, source_text)
            if score > 0:
                scored.append((phrase, score))

        # Stable sort: equal scores keep generation order.
        scored.sort(key=lambda item: item[1], reverse=True)

        final_scored: List[Tuple[str, float]] = []
        # Token lists of kept phrases with their (rounded) scores, computed once each.
        kept: List[Tuple[List[str], float]] = []
        for phrase, score in scored:
            tokens = self._phrase_tokens(phrase)
            if self._is_redundant(tokens, score, kept):
                continue

            rounded = round(score, 4)
            final_scored.append((phrase, rounded))
            kept.append((tokens, rounded))
            if len(final_scored) >= top_k:
                break

        return final_scored

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Interpret a request parameter as a boolean.

        Strings such as "false", "0", "no" or "" count as False (case- and
        whitespace-insensitive). Any other string counts as True. Non-strings use
        ordinary truthiness.
        """
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        return bool(value)

    @staticmethod
    def _build_response(
        generated_texts: List[str],
        ranked: List[Tuple[str, float]],
        return_scores: bool,
    ) -> Dict[str, Any]:
        """Assemble the response payload from raw generations and ranked keywords."""
        response: Dict[str, Any] = {
            "generated_texts": generated_texts,
            "keywords": [phrase for phrase, _ in ranked],
        }
        if return_scores:
            response["keyword_scores"] = [
                {"keyword": phrase, "score": score}
                for phrase, score in ranked
            ]
        return response

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one extraction request.

        Args:
            data: Request payload. ``"inputs"`` (required) is the source text.
                ``"parameters"`` (optional object) may contain ``max_input_length``
                (1024), ``max_new_tokens`` (32), ``num_beams`` (6),
                ``num_return_sequences`` (4), ``do_sample`` (False), ``temperature``
                (0.9), ``top_p`` (0.95), ``no_repeat_ngram_size`` (2),
                ``max_keyword_words`` (4), ``top_k_keywords`` (6) and
                ``return_scores`` (False).

        Returns:
            ``{"generated_texts": [...], "keywords": [...]}``, plus
            ``"keyword_scores"`` when ``return_scores`` is truthy. Invalid
            requests get ``{"error": <message>}`` instead of raising.
        """
        text = data.get("inputs")
        if text is None:
            return {"error": "Missing required field: inputs"}
        if not isinstance(text, str):
            return {"error": "The 'inputs' field must be a string"}

        # A JSON null "parameters" would otherwise crash on ``.get``.
        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}
        elif not isinstance(parameters, dict):
            return {"error": "The 'parameters' field must be an object"}

        try:
            max_input_length = int(parameters.get("max_input_length", 1024))
            max_new_tokens = int(parameters.get("max_new_tokens", 32))
            num_beams = int(parameters.get("num_beams", 6))
            num_return_sequences = int(parameters.get("num_return_sequences", 4))
            temperature = float(parameters.get("temperature", 0.9))
            top_p = float(parameters.get("top_p", 0.95))
            no_repeat_ngram_size = int(parameters.get("no_repeat_ngram_size", 2))
            max_keyword_words = int(parameters.get("max_keyword_words", 4))
            top_k_keywords = int(parameters.get("top_k_keywords", 6))
        except (TypeError, ValueError) as err:
            return {"error": f"Invalid parameters: {err}"}
        do_sample = self._as_bool(parameters.get("do_sample", False))
        return_scores = self._as_bool(parameters.get("return_scores", False))

        # Nothing to ground keyphrases in: skip generation rather than return
        # phrases the model invented from an empty prompt.
        if not text.strip():
            return self._build_response([], [], return_scores)

        if not do_sample:
            # Without sampling, generation cannot return more sequences than beams.
            num_return_sequences = min(num_return_sequences, num_beams)

        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}

        generate_kwargs = {
            **encoded,
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "num_return_sequences": num_return_sequences,
            "do_sample": do_sample,
            "no_repeat_ngram_size": no_repeat_ngram_size,
        }
        if num_beams > 1:
            # early_stopping is only used by beam-based decoding.
            generate_kwargs["early_stopping"] = True
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            output_ids = self.model.generate(**generate_kwargs)

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        generated_texts = [t for t in (self._normalize_space(d) for d in decoded) if t]

        candidates = self._parse_candidates(
            generated_texts=generated_texts,
            source_text=text,
            max_keyword_words=max_keyword_words,
        )
        ranked = self._dedupe_and_prune(
            phrases=candidates,
            source_text=text,
            top_k=top_k_keywords,
        )

        return self._build_response(generated_texts, ranked, return_scores)