Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Context building for the Text2SPARQL repair pipeline. | |
| Builds a compact context package from question analysis and KG profile. | |
| Provides higher-quality grounding hints through hybrid retrieval: | |
| - entities from live endpoint retrieval + lexical reranking | |
| - relations/classes from static schema profile + stronger lexical scoring | |
| """ | |
| from __future__ import annotations | |
| import html | |
| import logging | |
| import re | |
| import urllib.parse | |
| import urllib.request | |
| from functools import lru_cache | |
| from difflib import SequenceMatcher | |
| from typing import Any | |
| from SPARQLWrapper import JSON, POST, SPARQLWrapper | |
| from .config import RuntimeConfig | |
| from .models import ContextPackage, DatasetConfig, QueryRequest | |
| logger = logging.getLogger(__name__) | |
# ── Answer type inference patterns ───────────────────────────────
# Regexes applied to the lowercased question text by infer_answer_type().
_YES_NO_PATTERNS = [
    # Questions starting with an auxiliary verb usually expect an ASK form.
    r"^(is|are|was|were|do|does|did|has|have|had|can|could|will|would|should)\b",
    r"^(isn't|aren't|wasn't|weren't|don't|doesn't|didn't)\b",
]
# Phrasings that expect a COUNT aggregate; checked before the yes/no patterns.
_COUNT_PATTERNS = [
    r"\bhow\s+many\b",
    r"\bnumber\s+of\b",
    r"\bcount\s+of\b",
    r"\btotal\s+(number|count)\b",
]
# Function words ignored during lexical token matching.
_STOP_WORDS = {
    "who", "what", "which", "when", "how", "many", "the", "a", "an", "is", "are",
    "was", "were", "do", "does", "did", "of", "in", "on", "at", "to", "for",
    "with", "by", "from", "and", "or", "but", "about", "be", "been", "being",
    "all", "give", "me", "there", "any", "into", "as", "same", "than", "both",
    "first", "last", "most", "least", "total",
}
# Leading words that mark a span as question phrasing rather than an entity mention.
_QUESTION_LEAD_WORDS = {
    "what", "which", "who", "where", "when", "how", "give", "list", "name",
}
# Lowercase connectors allowed inside a capitalized entity span (e.g. "Duke of York").
_ENTITY_CONNECTORS = {"of", "the", "and", "de", "la", "le", "van", "von", "du", "del"}
# Role nouns trimmed from mention lead-ins such as "mayor of X" (see
# _postprocess_entity_mention).
_LEADING_ROLE_WORDS = {
    "ceo", "president", "mayor", "author", "director", "capital", "population",
    "king", "queen", "city", "country", "state", "province", "river", "mountain",
    "book", "books", "movie", "movies", "film", "album", "song", "band",
}
# Predicates commonly used for human-readable labels across knowledge graphs.
_COMMON_LABEL_PREDICATES = [
    "http://www.w3.org/2000/01/rdf-schema#label",
    "http://xmlns.com/foaf/0.1/name",
    "http://schema.org/name",
    "http://www.w3.org/2004/02/skos/core#prefLabel",
    "http://www.w3.org/2004/02/skos/core#altLabel",
]
# Predicates commonly used for short descriptions/abstracts.
_COMMON_DESCRIPTION_PREDICATES = [
    "http://www.w3.org/2000/01/rdf-schema#comment",
    "http://schema.org/description",
    "http://dbpedia.org/ontology/abstract",
]
_DBPEDIA_REDIRECT_PREDICATE = "http://dbpedia.org/ontology/wikiPageRedirects"
# Lightweight synonym expansion for relation-focused lexical matching
# (used only in "internal_max" linker mode; see _expand_schema_focus).
_RELATION_SYNONYM_MAP = {
    "population": ["people", "inhabitants", "residents"],
    "mayor": ["leader", "head", "governor"],
    "directed": ["director", "direct"],
    "wrote": ["author", "written", "write"],
    "books": ["book", "written"],
    "longest": ["length", "long"],
    "capital": ["capital city", "seat"],
}
# Keywords used by _type_aware_bonus to match coarse answer-type preferences.
_TYPE_HINT_KEYWORDS = {
    "person": ["person", "actor", "actress", "author", "writer", "director", "politician", "scientist", "musician"],
    "place": ["place", "city", "country", "state", "river", "mountain", "village", "town", "location"],
    "date": ["date", "year", "time", "period"],
    "organization": ["organization", "company", "band", "team", "agency", "airline", "alliance"],
}
# Process-wide cache for the lazily loaded embedding backend
# (populated by _init_embedding_backend).
_EMBEDDING_MODEL = None
_EMBEDDING_TOKENIZER = None
_EMBEDDING_MODEL_NAME = None
def infer_answer_type(question: str) -> str:
    """Classify the question into a SPARQL query form: count, ask, or select."""
    normalized = question.strip().lower()
    # Count cues take precedence: "how many ... are ..." must not become ASK.
    if any(re.search(pattern, normalized) for pattern in _COUNT_PATTERNS):
        return "count"
    if any(re.search(pattern, normalized) for pattern in _YES_NO_PATTERNS):
        return "ask"
    return "select"
| def _tokenize(text: str, remove_stop_words: bool = False) -> list[str]: | |
| """Simple word tokenization for matching.""" | |
| tokens = re.findall(r"[a-zA-Z0-9]+", text.lower()) | |
| if remove_stop_words: | |
| return [t for t in tokens if t not in _STOP_WORDS] | |
| return tokens | |
| def _normalize_text(text: str) -> str: | |
| """Normalize labels/surface forms for lexical matching.""" | |
| if not text: | |
| return "" | |
| text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text) | |
| text = text.replace("_", " ") | |
| text = re.sub(r"[\(\)\[\]\{\},;:/\\|]+", " ", text) | |
| text = re.sub(r"\s+", " ", text.lower()).strip() | |
| return text | |
def _fuzzy_match_score(query_tokens: list[str], label: str) -> float:
    """Jaccard-style token overlap between query tokens and a label."""
    label_set = set(_tokenize(_normalize_text(label)))
    if not label_set:
        return 0.0
    query_set = set(query_tokens)
    shared = query_set & label_set
    if not shared:
        return 0.0
    union_size = len(label_set) + len(query_set) - len(shared)
    return len(shared) / union_size
def _sequence_similarity(left: str, right: str) -> float:
    """Character-level difflib ratio on normalized strings; used for tie-breaks."""
    if left and right:
        matcher = SequenceMatcher(a=_normalize_text(left), b=_normalize_text(right))
        return matcher.ratio()
    return 0.0
def _split_uri_tail(uri: str) -> str:
    """Return the normalized local name (after the last '/' or '#') of a URI."""
    if not uri:
        return ""
    local = uri.rsplit("/", 1)[-1]
    local = local.rsplit("#", 1)[-1]
    return _normalize_text(local)
| def _strip_highlight_markup(text: str) -> str: | |
| """Remove simple HTML-ish lookup markup from labels/comments.""" | |
| if not text: | |
| return "" | |
| text = re.sub(r"</?B>", "", text) | |
| return html.unescape(text) | |
def _extract_entity_mentions(question: str) -> list[str]:
    """Extract likely entity surface forms from the question.

    This is intentionally heuristic. It targets common DBpedia mention shapes:
    - quoted spans
    - contiguous capitalized spans
    - capitalized spans with light connectors such as "of" or commas

    Returns at most 6 mentions, longest (in tokens, then characters) first,
    deduplicated on their normalized form.
    """
    mentions: list[str] = []
    # Quoted spans are usually strong mentions.
    for match in re.findall(r'"([^"]+)"|\'([^\']+)\'', question):
        # findall yields a 2-tuple per match; exactly one group is non-empty.
        mention = next((m for m in match if m), "").strip()
        if mention:
            mentions.append(mention)
    # Tokens are words (keeping inner ._'- chars) plus standalone commas,
    # so comma-joined spans like "Paris, Texas" can be accumulated.
    tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9._'-]*|,", question)
    current: list[str] = []

    def flush() -> None:
        # Commit the accumulated span as a mention, unless it begins with a
        # question lead word or is a lone stop word.
        nonlocal current
        if not current:
            return
        mention = " ".join(current).replace(" ,", ",").strip(" ,.;:!?")
        mention = _postprocess_entity_mention(mention)
        if mention:
            norm_tokens = _normalize_text(mention).split()
            first = norm_tokens[0] if norm_tokens else ""
            if first not in _QUESTION_LEAD_WORDS and not (
                len(norm_tokens) == 1 and first in _STOP_WORDS
            ):
                mentions.append(mention)
        current = []

    for tok in tokens:
        clean = tok.strip()
        lower = clean.lower()
        is_number = clean.isdigit()
        is_cap = bool(re.match(r"^[A-Z][A-Za-z0-9._'-]*$", clean))
        # A comma only extends an already-open span; it never starts one.
        if clean == "," and current:
            current.append(clean)
            continue
        if is_cap or is_number:
            current.append(clean)
            continue
        # Lowercase connectors ("of", "de", ...) may bridge capitalized words.
        if current and lower in _ENTITY_CONNECTORS:
            current.append(clean)
            continue
        # Any other lowercase token ends the current span.
        flush()
    flush()
    # Deduplicate while keeping longer, more informative spans first.
    unique: list[str] = []
    seen: set[str] = set()
    for mention in sorted(mentions, key=lambda x: (-len(_tokenize(x)), -len(x))):
        norm = _normalize_text(mention)
        if not norm or norm in seen:
            continue
        seen.add(norm)
        unique.append(mention)
    return unique[:6]
| def _postprocess_entity_mention(mention: str) -> str: | |
| """Trim obvious non-entity lead-ins from a detected mention span.""" | |
| tokens = mention.split() | |
| if len(tokens) >= 3: | |
| first = tokens[0] | |
| second = tokens[1].lower() | |
| if ( | |
| re.fullmatch(r"[A-Z]{2,}", first) | |
| and second == "of" | |
| ): | |
| return " ".join(tokens[2:]).strip() | |
| if first.lower() in _LEADING_ROLE_WORDS and second == "of": | |
| return " ".join(tokens[2:]).strip() | |
| return mention | |
def _expand_entity_surface_forms(mentions: list[str]) -> list[str]:
    """Expand detected mentions into retrieval-friendly surface-form variants.

    The goal is still lightweight lexical retrieval, but with more robust forms
    for exact/full-text lookup:
    - drop leading articles for title-like mentions
    - remove commas and parentheticals
    - keep comma-split head spans as a fallback disambiguation form

    Returns at most 10 variants, longest first, deduplicated on their
    normalized form.
    """
    variants: list[str] = []
    for mention in mentions:
        cleaned = mention.strip(" ,.;:!?")
        if not cleaned:
            continue
        candidates = [cleaned]
        # Variant without a trailing parenthetical, e.g. "Avatar (film)" -> "Avatar".
        no_parens = re.sub(r"\s*\([^)]*\)", "", cleaned).strip()
        if no_parens and no_parens != cleaned:
            candidates.append(no_parens)
        # Variant without a leading article, e.g. "The Beatles" -> "Beatles".
        tokens = cleaned.split()
        if len(tokens) >= 2 and tokens[0].lower() in {"the", "a", "an"}:
            candidates.append(" ".join(tokens[1:]))
        if "," in cleaned:
            # Comma flattened out ("Paris, Texas" -> "Paris Texas") ...
            comma_flat = re.sub(r"\s*,\s*", " ", cleaned).strip()
            if comma_flat:
                candidates.append(comma_flat)
            # ... plus the bare head span ("Paris") as a disambiguation fallback.
            head = cleaned.split(",", 1)[0].strip()
            if head:
                candidates.append(head)
        for candidate in candidates:
            candidate = re.sub(r"\s+", " ", candidate).strip(" ,.;:!?")
            if candidate:
                variants.append(candidate)
    # Deduplicate while keeping longer, more informative variants first.
    unique: list[str] = []
    seen: set[str] = set()
    for variant in sorted(variants, key=lambda x: (-len(_tokenize(x)), -len(x))):
        norm = _normalize_text(variant)
        if not norm or norm in seen:
            continue
        seen.add(norm)
        unique.append(variant)
    return unique[:10]
| def _build_relation_focus(question: str, entity_mentions: list[str]) -> str: | |
| """Create a question fragment more focused on relation/property semantics.""" | |
| focused = question | |
| for mention in entity_mentions: | |
| if mention: | |
| focused = re.sub(re.escape(mention), " ", focused, flags=re.IGNORECASE) | |
| focused = re.sub(r"\s+", " ", focused).strip() | |
| return focused | |
def _run_sparql_json(query: str, endpoint_url: str, timeout_sec: int = 8) -> dict[str, Any] | None:
    """Execute a small SPARQL JSON query and return the decoded payload.

    Returns None on any failure (network error, timeout, decode error); the
    error is logged at WARNING so one bad lookup never aborts context building.
    """
    try:
        sparql = SPARQLWrapper(endpoint_url)
        sparql.setQuery(query)
        sparql.setReturnFormat(JSON)
        sparql.setTimeout(timeout_sec)
        # POST avoids URL-length limits for the larger VALUES-based queries.
        sparql.setMethod(POST)
        return sparql.query().convert()
    except Exception as exc:
        logger.warning("SPARQL lookup failed: %s", exc)
        return None
def _http_json(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    data: bytes | None = None,
    timeout_sec: int = 10,
) -> dict[str, Any] | None:
    """Perform a small HTTP request and decode the JSON response.

    Caller-supplied headers override the default User-Agent. Returns None on
    any failure (network, timeout, decode); errors are logged at WARNING.
    """
    import json

    merged_headers = {"User-Agent": "Mozilla/5.0"}
    merged_headers.update(headers or {})
    try:
        req = urllib.request.Request(url, data=data, headers=merged_headers)
        with urllib.request.urlopen(req, timeout=timeout_sec) as response:
            return json.load(response)
    except Exception as exc:
        logger.warning("HTTP JSON lookup failed: %s", exc)
        return None
def _init_embedding_backend(model_name: str):
    """Lazy-load a lightweight sentence embedding model via transformers.

    Caches the (tokenizer, model) pair in module-level globals; a reload only
    happens when a different model name is requested.
    """
    global _EMBEDDING_MODEL, _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL_NAME
    if _EMBEDDING_MODEL is not None and _EMBEDDING_MODEL_NAME == model_name:
        return _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL
    # Imported lazily so transformers is only required when reranking is enabled.
    from transformers import AutoModel, AutoTokenizer
    logger.info("Loading embedding model for linker reranking: %s", model_name)
    _EMBEDDING_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    _EMBEDDING_MODEL = AutoModel.from_pretrained(model_name)
    _EMBEDDING_MODEL.eval()  # inference mode (disables dropout)
    _EMBEDDING_MODEL_NAME = model_name
    return _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL
| def _mean_pool_embeddings(model_output, attention_mask): | |
| import torch | |
| token_embeddings = model_output.last_hidden_state | |
| mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
| summed = (token_embeddings * mask).sum(dim=1) | |
| counts = mask.sum(dim=1).clamp(min=1e-9) | |
| return summed / counts | |
@lru_cache(maxsize=1024)
def _embed_single_text(text: str, model_name: str) -> tuple[float, ...]:
    """Embed one text string and cache the resulting unit-norm vector.

    The docstring always promised caching, but none was applied (and
    ``lru_cache`` was imported unused at module top). Question/candidate texts
    repeat heavily during reranking, so memoize on (text, model_name); the
    returned tuple is immutable and safe to share between callers.
    """
    import torch
    tokenizer, model = _init_embedding_backend(model_name)
    encoded = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=256,
    )
    with torch.no_grad():
        output = model(**encoded)
    pooled = _mean_pool_embeddings(output, encoded["attention_mask"])
    # L2-normalize so a plain dot product equals cosine similarity downstream.
    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    vector = pooled[0].cpu().tolist()
    return tuple(float(x) for x in vector)
| def _cosine_from_tuples(left: tuple[float, ...], right: tuple[float, ...]) -> float: | |
| if not left or not right or len(left) != len(right): | |
| return 0.0 | |
| return float(sum(a * b for a, b in zip(left, right))) | |
def _label_predicates_values() -> str:
    """Render the common label predicates as a SPARQL VALUES list."""
    wrapped = [f"<{predicate}>" for predicate in _COMMON_LABEL_PREDICATES]
    return " ".join(wrapped)
def _description_predicates_values() -> str:
    """Render the common description predicates as a SPARQL VALUES list."""
    wrapped = [f"<{predicate}>" for predicate in _COMMON_DESCRIPTION_PREDICATES]
    return " ".join(wrapped)
def _build_entity_lookup_query(surface_form: str, top_k: int) -> str:
    """Build an exact-label entity lookup query.

    Matches the surface form as an @en literal against any common label
    predicate. Double quotes are stripped (not escaped) and the phrase capped
    at 160 chars to keep the literal safe inside the query string.
    """
    safe_phrase = surface_form.replace('"', "").strip()[:160]
    return f"""
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT DISTINCT ?uri (SAMPLE(?label) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
VALUES ?lp {{ {_label_predicates_values()} }}
VALUES ?exactLabel {{ "{safe_phrase}"@en }}
?uri ?lp ?exactLabel .
BIND(?exactLabel AS ?label)
OPTIONAL {{
VALUES ?dp {{ {_description_predicates_values()} }}
?uri ?dp ?comment .
FILTER (LANG(?comment) = 'en')
}}
}}
LIMIT {max(top_k * 6, 20)}
"""
def _build_entity_fulltext_query(surface_form: str, top_k: int) -> str | None:
    """Build a faster full-text query for approximate entity retrieval.

    Uses the Virtuoso-specific ``bif:contains`` free-text predicate with
    prefix wildcards over content tokens (length >= 4, at most 4 tokens).
    Returns None when no usable token remains after stop-word filtering.
    """
    tokens = [t for t in _tokenize(surface_form, remove_stop_words=True) if len(t) >= 4]
    # Strip apostrophes and cap token length so the bif:contains literal stays valid.
    safe_tokens = [t.replace("'", "")[:40] for t in tokens[:4]]
    if not safe_tokens:
        return None
    fulltext = " OR ".join(f"'{tok}*'" for tok in safe_tokens)
    return f"""
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT DISTINCT ?uri (SAMPLE(?label) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
VALUES ?lp {{ {_label_predicates_values()} }}
?uri ?lp ?label .
FILTER (LANG(?label) = 'en')
?label bif:contains "{fulltext}" .
OPTIONAL {{
VALUES ?dp {{ {_description_predicates_values()} }}
?uri ?dp ?comment .
FILTER (LANG(?comment) = 'en')
}}
}}
LIMIT {max(top_k * 8, 30)}
"""
def _build_entity_redirect_query(surface_form: str, top_k: int) -> str:
    """Build a redirect-aware entity lookup query for DBpedia.

    Resolves a label match on a redirect page through
    ``dbo:wikiPageRedirects`` to the canonical resource, preferring the
    canonical resource's own English label when available.
    """
    safe_phrase = surface_form.replace('"', "").strip()[:160]
    return f"""
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT DISTINCT ?uri (SAMPLE(?labelResolved) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
VALUES ?lp {{ {_label_predicates_values()} }}
VALUES ?exactLabel {{ "{safe_phrase}"@en }}
?redirect ?lp ?exactLabel .
?redirect <{_DBPEDIA_REDIRECT_PREDICATE}> ?uri .
OPTIONAL {{
VALUES ?lp2 {{ {_label_predicates_values()} }}
?uri ?lp2 ?labelResolved .
FILTER (LANG(?labelResolved) = 'en')
}}
BIND(COALESCE(?labelResolved, ?exactLabel) AS ?labelResolved)
OPTIONAL {{
VALUES ?dp {{ {_description_predicates_values()} }}
?uri ?dp ?comment .
FILTER (LANG(?comment) = 'en')
}}
}}
LIMIT {max(top_k * 4, 15)}
"""
| def _infer_entity_type_preference(question: str) -> str | None: | |
| """Infer a coarse answer/entity preference from the question wording.""" | |
| q = question.strip().lower() | |
| if q.startswith("who"): | |
| return "person" | |
| if q.startswith("where"): | |
| return "place" | |
| if q.startswith("when"): | |
| return "date" | |
| if "airline" in q or "company" in q or "band" in q: | |
| return "organization" | |
| return None | |
def _type_aware_bonus(question: str, candidate: dict[str, Any]) -> float:
    """Give a small bonus when candidate text matches question type expectations."""
    preference = _infer_entity_type_preference(question)
    if preference is None:
        return 0.0
    # Search the label, comment, and URI tail together for type keywords.
    pieces = [
        candidate.get("label", ""),
        candidate.get("comment", ""),
        _split_uri_tail(candidate.get("uri", "")),
    ]
    haystack = " ".join(pieces).lower()
    for keyword in _TYPE_HINT_KEYWORDS.get(preference, []):
        if keyword in haystack:
            return 0.8
    return 0.0
def _expand_schema_focus(question_focus: str, linker_mode: str) -> str:
    """Append relation synonyms to the focus text (internal_max mode only)."""
    if linker_mode != "internal_max":
        return question_focus
    synonyms: list[str] = []
    for token in _tokenize(question_focus, remove_stop_words=True):
        synonyms.extend(_RELATION_SYNONYM_MAP.get(token, []))
    if not synonyms:
        return question_focus
    return " ".join([question_focus, *synonyms]).strip()
def _score_entity_candidate(
    question: str,
    surface_form: str,
    candidate: dict[str, Any],
    linker_mode: str,
) -> float:
    """Rerank an entity candidate using label, URI-tail, and comment overlap."""
    question_tokens = _tokenize(question, remove_stop_words=True)
    mention_tokens = _tokenize(surface_form, remove_stop_words=True)
    label = candidate.get("label", "")
    comment = candidate.get("comment", "")
    uri = candidate.get("uri", "")

    label_norm = _normalize_text(label)
    surface_norm = _normalize_text(surface_form)

    score = 0.0
    # Exact and substring label matches carry the most weight.
    if label_norm == surface_norm:
        score += 3.0
    if surface_norm in label_norm:
        score += 1.75
    score += 1.5 * _fuzzy_match_score(mention_tokens or question_tokens, label)
    score += 0.9 * _fuzzy_match_score(question_tokens, label)
    score += 0.9 * _sequence_similarity(surface_form, _split_uri_tail(uri))
    score += 0.8 * _fuzzy_match_score(question_tokens, comment)
    # Mild preference for canonical DBpedia resources.
    if "http://dbpedia.org/resource/" in uri:
        score += 0.1
    if linker_mode == "internal_max":
        score += _type_aware_bonus(question, candidate)
        tail = _split_uri_tail(uri)
        # Penalize film/album/disambiguation pages that often swamp exact hits.
        if any(flag in tail for flag in ("film", "album", "disambiguation")):
            score -= 0.3
    return round(score, 4)
def retrieve_entities_hybrid(
    question: str,
    endpoint_url: str,
    top_k: int,
    linker_mode: str = "internal_min",
) -> list[dict]:
    """Retrieve and rerank entity candidates from a live endpoint.

    Runs exact-label, full-text, and (in internal_max mode) redirect-aware
    lookups for each detected surface form, merges results by URI, and keeps
    the best lexical score observed for each candidate.
    """
    mentions = _extract_entity_mentions(question)
    surface_forms = _expand_entity_surface_forms(mentions)
    if not surface_forms:
        # No mention detected: fall back to the first few content words.
        fallback = " ".join(_tokenize(question, remove_stop_words=True)[:5])
        surface_forms = [fallback] if fallback else []
    merged: dict[str, dict[str, Any]] = {}
    for mention in surface_forms:
        queries = [_build_entity_lookup_query(mention, top_k)]
        fuzzy_query = _build_entity_fulltext_query(mention, top_k)
        if fuzzy_query is not None:
            queries.append(fuzzy_query)
        if linker_mode == "internal_max":
            queries.append(_build_entity_redirect_query(mention, top_k))
        for idx, query in enumerate(queries):
            # The exact lookup (idx 0) gets a tighter timeout than fuzzier queries.
            raw = _run_sparql_json(query, endpoint_url, timeout_sec=4 if idx == 0 else 5)
            if not raw:
                continue
            for row in raw.get("results", {}).get("bindings", []):
                uri = row.get("uri", {}).get("value", "")
                label = row.get("label", {}).get("value", "")
                comment = row.get("comment", {}).get("value", "")
                if not uri or not label:
                    continue
                candidate = merged.get(uri, {
                    "uri": uri,
                    "label": label,
                    "comment": comment,
                    "matched_mentions": [],
                    "score": 0.0,
                })
                if mention not in candidate["matched_mentions"]:
                    candidate["matched_mentions"].append(mention)
                # First non-empty comment wins; later rows never overwrite it.
                if comment and not candidate.get("comment"):
                    candidate["comment"] = comment
                # Keep the best score over all mentions/queries hitting this URI.
                score = _score_entity_candidate(question, mention, candidate, linker_mode)
                candidate["score"] = max(candidate["score"], score)
                merged[uri] = candidate
    # Ties broken by shorter (more canonical) labels.
    ranked = sorted(
        merged.values(),
        key=lambda x: (-x["score"], len(_normalize_text(x.get("label", "")))),
    )
    return ranked[:top_k]
def retrieve_entities_dbpedia_lookup(
    question: str,
    top_k: int,
    lookup_url: str,
) -> list[dict]:
    """Retrieve DBpedia entities using the official Lookup service.

    Queries the Lookup keyword-search endpoint per surface form, merges hits
    by URI, and scores each as lexical score plus a capped backend-relevance
    bonus.
    """
    mentions = _expand_entity_surface_forms(_extract_entity_mentions(question))
    if not mentions:
        # No mention detected: fall back to the first few content words.
        fallback = " ".join(_tokenize(question, remove_stop_words=True)[:5])
        mentions = [fallback] if fallback else []
    merged: dict[str, dict[str, Any]] = {}
    for mention in mentions[:6]:
        params = urllib.parse.urlencode(
            {
                "format": "json",
                "query": mention,
                "maxResults": str(max(top_k * 5, 10)),
            }
        )
        payload = _http_json(
            f"{lookup_url}?{params}",
            headers={"Accept": "application/json"},
            timeout_sec=12,
        )
        if not payload:
            continue
        for row in payload.get("docs", []):
            # Lookup wraps most fields in single-element lists; unwrap safely.
            uri = next(iter(row.get("resource", [])), "") or next(iter(row.get("id", [])), "")
            label = _strip_highlight_markup(next(iter(row.get("label", [])), ""))
            comment = _strip_highlight_markup(next(iter(row.get("comment", [])), ""))
            score_raw = next(iter(row.get("score", [])), "0")
            if not uri:
                continue
            candidate = merged.get(
                uri,
                {
                    "uri": uri,
                    "label": label or _split_uri_tail(uri),
                    "comment": comment,
                    "score": 0.0,
                },
            )
            lexical = _score_entity_candidate(question, mention, candidate, "internal_max")
            try:
                backend_score = float(score_raw)
            except Exception:
                backend_score = 0.0
            # Blend lexical score with a capped (<= 2.0) backend relevance bonus.
            candidate["score"] = max(candidate["score"], lexical + min(backend_score / 10000.0, 2.0))
            merged[uri] = candidate
    ranked = sorted(merged.values(), key=lambda item: (-item["score"], item.get("label", "")))
    return ranked[:top_k]
def retrieve_entities_dbpedia_spotlight(
    question: str,
    top_k: int,
    spotlight_base_url: str,
    language: str | None,
) -> list[dict]:
    """Retrieve DBpedia entities using the official Spotlight API.

    POSTs the question to the Spotlight /annotate endpoint and blends the
    service's similarity and support signals with the local lexical score.
    """
    lang = (language or "en").lower()
    annotate_url = f"{spotlight_base_url.rstrip('/')}/{lang}/annotate"
    payload = _http_json(
        annotate_url,
        headers={"Accept": "application/json"},
        data=urllib.parse.urlencode(
            {"text": question, "confidence": "0.35"}
        ).encode("utf-8"),
        timeout_sec=15,
    )
    if not payload:
        return []
    resources = payload.get("Resources", [])
    # Spotlight returns a bare dict when exactly one resource is annotated.
    if isinstance(resources, dict):
        resources = [resources]
    merged: dict[str, dict[str, Any]] = {}
    for row in resources:
        uri = row.get("@URI", "")
        surface_form = row.get("@surfaceForm", "")
        if not uri:
            continue
        candidate = merged.get(
            uri,
            {
                "uri": uri,
                "label": surface_form or _split_uri_tail(uri),
                "comment": "",
                "score": 0.0,
            },
        )
        lexical = _score_entity_candidate(question, surface_form or _split_uri_tail(uri), candidate, "internal_max")
        try:
            similarity = float(row.get("@similarityScore", "0") or 0.0)
        except Exception:
            similarity = 0.0
        try:
            support = float(row.get("@support", "0") or 0.0)
        except Exception:
            support = 0.0
        # Blend lexical score, Spotlight similarity, and capped page support.
        candidate["score"] = max(candidate["score"], lexical + 1.5 * similarity + min(support / 10000.0, 1.0))
        merged[uri] = candidate
    ranked = sorted(merged.values(), key=lambda item: (-item["score"], item.get("label", "")))
    return ranked[:top_k]
def retrieve_entities_embedding_rerank(
    question: str,
    endpoint_url: str,
    top_k: int,
    model_name: str,
    candidate_pool_size: int,
) -> list[dict]:
    """Hybrid entity linker: lexical retrieval first, embeddings for reranking.

    Pulls a wider lexical candidate pool, embeds the question plus each
    surface form and each candidate's text, and rescores candidates with a
    weighted mix of lexical score and best cosine similarity.
    """
    pool_size = max(candidate_pool_size, top_k, 8)
    lexical_candidates = retrieve_entities_hybrid(
        question,
        endpoint_url,
        pool_size,
        linker_mode="internal_max",
    )
    if not lexical_candidates:
        return []
    # Embed multiple "views" of the query: the full question and each mention.
    mentions = _expand_entity_surface_forms(_extract_entity_mentions(question))
    query_views = [question] + mentions
    query_vectors = [_embed_single_text(text, model_name) for text in query_views if text.strip()]
    reranked: list[dict[str, Any]] = []
    for candidate in lexical_candidates:
        candidate_text = " ".join(
            [
                candidate.get("label", ""),
                candidate.get("comment", ""),
                _split_uri_tail(candidate.get("uri", "")),
            ]
        ).strip()
        if not candidate_text:
            candidate_text = candidate.get("uri", "")
        candidate_vector = _embed_single_text(candidate_text, model_name)
        # Best similarity over all query views (vectors are unit-norm, so
        # the dot product is cosine similarity).
        semantic_similarity = max(
            (_cosine_from_tuples(query_vector, candidate_vector) for query_vector in query_vectors),
            default=0.0,
        )
        candidate_copy = dict(candidate)
        candidate_copy["lexical_score"] = candidate.get("score", 0.0)
        candidate_copy["semantic_similarity"] = round(semantic_similarity, 4)
        candidate_copy["score"] = round(
            0.45 * candidate.get("score", 0.0) + 4.0 * semantic_similarity,
            4,
        )
        reranked.append(candidate_copy)
    reranked.sort(key=lambda item: (-item["score"], item.get("label", "")))
    return reranked[:top_k]
def retrieve_entity_candidates(
    question: str,
    kg_profile: dict,
    top_k: int,
    endpoint_url: str | None = None,
    linker_mode: str = "internal_min",
    lookup_url: str | None = None,
    spotlight_base_url: str | None = None,
    language: str | None = None,
    embedding_model_name: str | None = None,
    embedding_candidate_pool_size: int = 16,
) -> list[dict]:
    """Retrieve top entity candidates from KG profile or live endpoint.

    Dispatch order: the configured external linker (embedding rerank,
    DBpedia Lookup, or Spotlight) is tried first; each falls through on an
    empty result. Then the live-endpoint hybrid linker, and finally a purely
    lexical match against the static KG profile.
    """
    if linker_mode == "embedding_rerank" and endpoint_url and embedding_model_name:
        external = retrieve_entities_embedding_rerank(
            question,
            endpoint_url,
            top_k,
            embedding_model_name,
            embedding_candidate_pool_size,
        )
        if external:
            return external
    if linker_mode == "dbpedia_lookup" and lookup_url:
        external = retrieve_entities_dbpedia_lookup(question, top_k, lookup_url)
        if external:
            return external
    if linker_mode == "dbpedia_spotlight" and spotlight_base_url:
        external = retrieve_entities_dbpedia_spotlight(
            question,
            top_k,
            spotlight_base_url,
            language,
        )
        if external:
            return external
    if endpoint_url:
        live_candidates = retrieve_entities_hybrid(
            question,
            endpoint_url,
            top_k,
            linker_mode=linker_mode,
        )
        if live_candidates:
            return live_candidates
    # Offline fallback: lexical scoring over the static profile entities.
    entities = kg_profile.get("entities", [])
    if not entities:
        return []
    q_tokens = _tokenize(question, remove_stop_words=True)
    scored = []
    for entity in entities:
        label = entity.get("label", "")
        uri = entity.get("uri", "")
        score = (
            1.2 * _fuzzy_match_score(q_tokens, label)
            + 0.4 * _sequence_similarity(question, label)
        )
        if score > 0:
            scored.append({
                "uri": uri,
                "label": label,
                "score": round(score, 4),
            })
    scored.sort(key=lambda x: x["score"], reverse=True)
    return scored[:top_k]
def _score_schema_candidate(
    question_focus: str,
    label: str,
    uri: str,
    *,
    prefer_ontology: bool = False,
) -> float:
    """Score a property/class candidate using multiple lexical signals."""
    focus_tokens = _tokenize(question_focus, remove_stop_words=True)
    label_norm = _normalize_text(label)
    # Any focus token (longer than 2 chars) appearing inside the label.
    has_substring = any(
        token in label_norm for token in focus_tokens if len(token) > 2
    )
    score = (
        1.8 * _fuzzy_match_score(focus_tokens, label_norm)
        + 0.8 * _sequence_similarity(question_focus, label_norm)
        + 0.7 * (1.0 if has_substring else 0.0)
        + 0.5 * _sequence_similarity(question_focus, _split_uri_tail(uri))
    )
    # Slightly favor curated /ontology/ terms over raw infobox keys when asked.
    if prefer_ontology and "/ontology/" in uri:
        score += 0.2
    return round(score, 4)
def retrieve_relation_candidates(
    question: str, kg_profile: dict, top_k: int, linker_mode: str = "internal_min"
) -> list[dict]:
    """Retrieve top relation/property candidates from KG profile."""
    properties = kg_profile.get("properties", [])
    if not properties:
        return []
    # Blank out entity spans so scoring focuses on the relation wording.
    focus = _expand_schema_focus(
        _build_relation_focus(question, _extract_entity_mentions(question)),
        linker_mode,
    )
    candidates = []
    for prop in properties:
        uri = prop.get("uri", "")
        label = prop.get("label", "")
        score = _score_schema_candidate(focus, label, uri, prefer_ontology=True)
        if score > 0:
            candidates.append({"uri": uri, "label": label, "score": score})
    candidates.sort(key=lambda item: item["score"], reverse=True)
    return candidates[:top_k]
def retrieve_class_candidates(
    question: str, kg_profile: dict, top_k: int, linker_mode: str = "internal_min"
) -> list[dict]:
    """Retrieve top class/type candidates from KG profile."""
    classes = kg_profile.get("classes", [])
    if not classes:
        return []
    # Blank out entity spans so scoring focuses on the type wording.
    focus = _expand_schema_focus(
        _build_relation_focus(question, _extract_entity_mentions(question)),
        linker_mode,
    )
    candidates = []
    for cls in classes:
        uri = cls.get("uri", "")
        label = cls.get("label", "")
        score = _score_schema_candidate(focus, label, uri)
        if score > 0:
            candidates.append({"uri": uri, "label": label, "score": score})
    candidates.sort(key=lambda item: item["score"], reverse=True)
    return candidates[:top_k]
def build_context(
    request: QueryRequest,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
    kg_profile: dict,
) -> ContextPackage:
    """Build the compact context package for generation and expert prompts.

    Gathers entity/relation/class candidates and an answer-type hint, records
    human-readable notes about anything missing, and logs a summary.
    """
    question = request.question
    answer_type = infer_answer_type(question)
    entity_mentions = _extract_entity_mentions(question)
    relation_candidates = retrieve_relation_candidates(
        question, kg_profile, runtime.relation_top_k, linker_mode=runtime.linker_mode
    )
    class_candidates = retrieve_class_candidates(
        question, kg_profile, runtime.class_top_k, linker_mode=runtime.linker_mode
    )
    # Live endpoint / external linker URLs are only passed in dbpedia mode;
    # other datasets fall back to the static profile inside the retriever.
    entity_candidates = retrieve_entity_candidates(
        question,
        kg_profile,
        runtime.entity_top_k,
        endpoint_url=dataset.endpoint_url if dataset.mode == "dbpedia" else None,
        linker_mode=runtime.linker_mode,
        lookup_url=runtime.dbpedia_lookup_url if dataset.mode == "dbpedia" else None,
        spotlight_base_url=runtime.dbpedia_spotlight_base_url if dataset.mode == "dbpedia" else None,
        language=request.language,
        embedding_model_name=runtime.embedding_model_name,
        embedding_candidate_pool_size=runtime.embedding_candidate_pool_size,
    )
    # Notes surface grounding gaps to the downstream prompts.
    notes: list[str] = []
    if entity_mentions:
        notes.append(f"Detected entity mentions: {', '.join(entity_mentions)}")
    if not entity_candidates:
        notes.append("No entity candidates found from endpoint/profile.")
    if not relation_candidates:
        notes.append("No relation candidates found in KG profile.")
    if not class_candidates:
        notes.append("No class candidates found in KG profile.")
    context = ContextPackage(
        entity_candidates=entity_candidates,
        relation_candidates=relation_candidates,
        class_candidates=class_candidates,
        answer_type_hint=answer_type,
        prefix_hints=dataset.default_prefixes,
        notes=notes,
    )
    logger.info(
        "Context built: %d entities, %d relations, %d classes, type=%s",
        len(entity_candidates),
        len(relation_candidates),
        len(class_candidates),
        answer_type,
    )
    return context