# iris-at-text2sparql / src/context.py
# Author: Alex Latipov
# Commit d745844: Harden frozen eval prompts and judge JSON handling
"""Context building for the Text2SPARQL repair pipeline.
Builds a compact context package from question analysis and KG profile.
Provides higher-quality grounding hints through hybrid retrieval:
- entities from live endpoint retrieval + lexical reranking
- relations/classes from static schema profile + stronger lexical scoring
"""
from __future__ import annotations
import html
import logging
import re
import urllib.parse
import urllib.request
from functools import lru_cache
from difflib import SequenceMatcher
from typing import Any
from SPARQLWrapper import JSON, POST, SPARQLWrapper
from .config import RuntimeConfig
from .models import ContextPackage, DatasetConfig, QueryRequest
logger = logging.getLogger(__name__)
# ── Answer type inference patterns ───────────────────────────────
_YES_NO_PATTERNS = [
r"^(is|are|was|were|do|does|did|has|have|had|can|could|will|would|should)\b",
r"^(isn't|aren't|wasn't|weren't|don't|doesn't|didn't)\b",
]
_COUNT_PATTERNS = [
r"\bhow\s+many\b",
r"\bnumber\s+of\b",
r"\bcount\s+of\b",
r"\btotal\s+(number|count)\b",
]
_STOP_WORDS = {
"who", "what", "which", "when", "how", "many", "the", "a", "an", "is", "are",
"was", "were", "do", "does", "did", "of", "in", "on", "at", "to", "for",
"with", "by", "from", "and", "or", "but", "about", "be", "been", "being",
"all", "give", "me", "there", "any", "into", "as", "same", "than", "both",
"first", "last", "most", "least", "total",
}
_QUESTION_LEAD_WORDS = {
"what", "which", "who", "where", "when", "how", "give", "list", "name",
}
_ENTITY_CONNECTORS = {"of", "the", "and", "de", "la", "le", "van", "von", "du", "del"}
_LEADING_ROLE_WORDS = {
"ceo", "president", "mayor", "author", "director", "capital", "population",
"king", "queen", "city", "country", "state", "province", "river", "mountain",
"book", "books", "movie", "movies", "film", "album", "song", "band",
}
_COMMON_LABEL_PREDICATES = [
"http://www.w3.org/2000/01/rdf-schema#label",
"http://xmlns.com/foaf/0.1/name",
"http://schema.org/name",
"http://www.w3.org/2004/02/skos/core#prefLabel",
"http://www.w3.org/2004/02/skos/core#altLabel",
]
_COMMON_DESCRIPTION_PREDICATES = [
"http://www.w3.org/2000/01/rdf-schema#comment",
"http://schema.org/description",
"http://dbpedia.org/ontology/abstract",
]
_DBPEDIA_REDIRECT_PREDICATE = "http://dbpedia.org/ontology/wikiPageRedirects"
_RELATION_SYNONYM_MAP = {
"population": ["people", "inhabitants", "residents"],
"mayor": ["leader", "head", "governor"],
"directed": ["director", "direct"],
"wrote": ["author", "written", "write"],
"books": ["book", "written"],
"longest": ["length", "long"],
"capital": ["capital city", "seat"],
}
_TYPE_HINT_KEYWORDS = {
"person": ["person", "actor", "actress", "author", "writer", "director", "politician", "scientist", "musician"],
"place": ["place", "city", "country", "state", "river", "mountain", "village", "town", "location"],
"date": ["date", "year", "time", "period"],
"organization": ["organization", "company", "band", "team", "agency", "airline", "alliance"],
}
_EMBEDDING_MODEL = None
_EMBEDDING_TOKENIZER = None
_EMBEDDING_MODEL_NAME = None
def infer_answer_type(question: str) -> str:
    """Infer the expected SPARQL query form from the question."""
    text = question.strip().lower()
    # Count cues win over yes/no cues ("How many ..." also starts with a verb-like word).
    if any(re.search(pattern, text) for pattern in _COUNT_PATTERNS):
        return "count"
    if any(re.search(pattern, text) for pattern in _YES_NO_PATTERNS):
        return "ask"
    return "select"
def _tokenize(text: str, remove_stop_words: bool = False) -> list[str]:
"""Simple word tokenization for matching."""
tokens = re.findall(r"[a-zA-Z0-9]+", text.lower())
if remove_stop_words:
return [t for t in tokens if t not in _STOP_WORDS]
return tokens
def _normalize_text(text: str) -> str:
"""Normalize labels/surface forms for lexical matching."""
if not text:
return ""
text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
text = text.replace("_", " ")
text = re.sub(r"[\(\)\[\]\{\},;:/\\|]+", " ", text)
text = re.sub(r"\s+", " ", text.lower()).strip()
return text
def _fuzzy_match_score(query_tokens: list[str], label: str) -> float:
    """Compute a simple overlap score between query tokens and a label."""
    label_set = set(_tokenize(_normalize_text(label)))
    if not label_set:
        return 0.0
    query_set = set(query_tokens)
    shared = len(query_set & label_set)
    if shared == 0:
        return 0.0
    # Jaccard similarity: |intersection| / |union|.
    return shared / len(query_set | label_set)
def _sequence_similarity(left: str, right: str) -> float:
    """Cheap character-level similarity for tie-breaking."""
    if left and right:
        # Compare normalized forms so casing/punctuation do not dominate.
        matcher = SequenceMatcher(a=_normalize_text(left), b=_normalize_text(right))
        return matcher.ratio()
    return 0.0
def _split_uri_tail(uri: str) -> str:
    """Extract a readable tail from a URI for lexical comparison."""
    if not uri:
        return ""
    # Take the last path segment, then the last fragment segment.
    after_slash = uri.rsplit("/", 1)[-1]
    fragment = after_slash.rsplit("#", 1)[-1]
    return _normalize_text(fragment)
def _strip_highlight_markup(text: str) -> str:
"""Remove simple HTML-ish lookup markup from labels/comments."""
if not text:
return ""
text = re.sub(r"</?B>", "", text)
return html.unescape(text)
def _extract_entity_mentions(question: str) -> list[str]:
    """Extract likely entity surface forms from the question.

    This is intentionally heuristic. It targets common DBpedia mention shapes:
    - quoted spans
    - contiguous capitalized spans
    - capitalized spans with light connectors such as "of" or commas

    Returns at most 6 mentions, longest (by token count, then length) first,
    deduplicated on their normalized form.
    """
    mentions: list[str] = []
    # Quoted spans are usually strong mentions.
    for match in re.findall(r'"([^"]+)"|\'([^\']+)\'', question):
        # findall yields (double-quoted, single-quoted) pairs; keep the non-empty side.
        mention = next((m for m in match if m), "").strip()
        if mention:
            mentions.append(mention)
    # Scan word-and-comma tokens, accumulating the open capitalized run in `current`.
    tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9._'-]*|,", question)
    current: list[str] = []
    def flush() -> None:
        # Close the open span: clean it and keep it unless it is just a question
        # lead word ("which", "who", ...) or a lone stop word.
        nonlocal current
        if not current:
            return
        mention = " ".join(current).replace(" ,", ",").strip(" ,.;:!?")
        mention = _postprocess_entity_mention(mention)
        if mention:
            norm_tokens = _normalize_text(mention).split()
            first = norm_tokens[0] if norm_tokens else ""
            if first not in _QUESTION_LEAD_WORDS and not (
                len(norm_tokens) == 1 and first in _STOP_WORDS
            ):
                mentions.append(mention)
        current = []
    for tok in tokens:
        clean = tok.strip()
        lower = clean.lower()
        is_number = clean.isdigit()
        is_cap = bool(re.match(r"^[A-Z][A-Za-z0-9._'-]*$", clean))
        # A comma continues a span only when one is already open ("Paris, France").
        if clean == "," and current:
            current.append(clean)
            continue
        if is_cap or is_number:
            current.append(clean)
            continue
        # Lowercase connectors ("of", "de", "van", ...) may join capitalized runs.
        if current and lower in _ENTITY_CONNECTORS:
            current.append(clean)
            continue
        flush()
    flush()
    # Deduplicate while keeping longer, more informative spans first.
    unique: list[str] = []
    seen: set[str] = set()
    for mention in sorted(mentions, key=lambda x: (-len(_tokenize(x)), -len(x))):
        norm = _normalize_text(mention)
        if not norm or norm in seen:
            continue
        seen.add(norm)
        unique.append(mention)
    return unique[:6]
def _postprocess_entity_mention(mention: str) -> str:
"""Trim obvious non-entity lead-ins from a detected mention span."""
tokens = mention.split()
if len(tokens) >= 3:
first = tokens[0]
second = tokens[1].lower()
if (
re.fullmatch(r"[A-Z]{2,}", first)
and second == "of"
):
return " ".join(tokens[2:]).strip()
if first.lower() in _LEADING_ROLE_WORDS and second == "of":
return " ".join(tokens[2:]).strip()
return mention
def _expand_entity_surface_forms(mentions: list[str]) -> list[str]:
    """Expand detected mentions into retrieval-friendly surface-form variants.

    The goal is still lightweight lexical retrieval, but with more robust forms
    for exact/full-text lookup:
    - drop leading articles for title-like mentions
    - remove commas and parentheticals
    - keep comma-split head spans as a fallback disambiguation form

    Returns at most 10 variants, longest first, deduplicated on normalized form.
    """
    variants: list[str] = []
    for mention in mentions:
        cleaned = mention.strip(" ,.;:!?")
        if not cleaned:
            continue
        # The raw mention is always kept as the primary candidate.
        candidates = [cleaned]
        # Variant: strip parenthetical qualifiers, e.g. "Avatar (film)" -> "Avatar".
        no_parens = re.sub(r"\s*\([^)]*\)", "", cleaned).strip()
        if no_parens and no_parens != cleaned:
            candidates.append(no_parens)
        # Variant: drop a leading article for title-like mentions.
        tokens = cleaned.split()
        if len(tokens) >= 2 and tokens[0].lower() in {"the", "a", "an"}:
            candidates.append(" ".join(tokens[1:]))
        if "," in cleaned:
            # Variants: comma-free form, plus the head span before the first comma.
            comma_flat = re.sub(r"\s*,\s*", " ", cleaned).strip()
            if comma_flat:
                candidates.append(comma_flat)
            head = cleaned.split(",", 1)[0].strip()
            if head:
                candidates.append(head)
        for candidate in candidates:
            candidate = re.sub(r"\s+", " ", candidate).strip(" ,.;:!?")
            if candidate:
                variants.append(candidate)
    # Deduplicate, preferring longer (more specific) variants first; the stable
    # sort preserves original append order on ties.
    unique: list[str] = []
    seen: set[str] = set()
    for variant in sorted(variants, key=lambda x: (-len(_tokenize(x)), -len(x))):
        norm = _normalize_text(variant)
        if not norm or norm in seen:
            continue
        seen.add(norm)
        unique.append(variant)
    return unique[:10]
def _build_relation_focus(question: str, entity_mentions: list[str]) -> str:
"""Create a question fragment more focused on relation/property semantics."""
focused = question
for mention in entity_mentions:
if mention:
focused = re.sub(re.escape(mention), " ", focused, flags=re.IGNORECASE)
focused = re.sub(r"\s+", " ", focused).strip()
return focused
def _run_sparql_json(query: str, endpoint_url: str, timeout_sec: int = 8) -> dict[str, Any] | None:
    """Execute a small SPARQL JSON query and return the decoded payload."""
    try:
        client = SPARQLWrapper(endpoint_url)
        client.setMethod(POST)
        client.setReturnFormat(JSON)
        client.setTimeout(timeout_sec)
        client.setQuery(query)
        return client.query().convert()
    except Exception as exc:
        # Retrieval is best-effort: log and degrade instead of failing the pipeline.
        logger.warning("SPARQL lookup failed: %s", exc)
        return None
def _http_json(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    data: bytes | None = None,
    timeout_sec: int = 10,
) -> dict[str, Any] | None:
    """Execute a small HTTP request and decode JSON.

    A GET is issued when *data* is None; otherwise a POST with the given body.
    Returns None on any failure (best-effort lookup).
    """
    try:
        import json

        # Default User-Agent first so caller-supplied headers can override it.
        merged_headers = {"User-Agent": "Mozilla/5.0"}
        merged_headers.update(headers or {})
        request = urllib.request.Request(url, data=data, headers=merged_headers)
        with urllib.request.urlopen(request, timeout=timeout_sec) as response:
            return json.load(response)
    except Exception as exc:
        logger.warning("HTTP JSON lookup failed: %s", exc)
        return None
def _init_embedding_backend(model_name: str):
    """Lazy-load a lightweight sentence embedding model via transformers.

    Caches the tokenizer/model pair in module globals and reuses it while the
    requested ``model_name`` matches the cached one; a different name triggers
    a reload. Returns ``(tokenizer, model)``.
    """
    global _EMBEDDING_MODEL, _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL_NAME
    if _EMBEDDING_MODEL is not None and _EMBEDDING_MODEL_NAME == model_name:
        return _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL
    # Imported lazily so transformers is only required in embedding linker modes.
    from transformers import AutoModel, AutoTokenizer
    logger.info("Loading embedding model for linker reranking: %s", model_name)
    _EMBEDDING_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    _EMBEDDING_MODEL = AutoModel.from_pretrained(model_name)
    _EMBEDDING_MODEL.eval()  # inference only
    _EMBEDDING_MODEL_NAME = model_name
    return _EMBEDDING_TOKENIZER, _EMBEDDING_MODEL
def _mean_pool_embeddings(model_output, attention_mask):
import torch
token_embeddings = model_output.last_hidden_state
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = (token_embeddings * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp(min=1e-9)
return summed / counts
@lru_cache(maxsize=512)
def _embed_single_text(text: str, model_name: str) -> tuple[float, ...]:
    """Embed one text string and cache the vector.

    Returns an L2-normalized, mean-pooled embedding as a hashable tuple so
    ``lru_cache`` can memoize it per ``(text, model_name)`` pair.
    """
    import torch
    tokenizer, model = _init_embedding_backend(model_name)
    encoded = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=256,
    )
    with torch.no_grad():
        output = model(**encoded)
        pooled = _mean_pool_embeddings(output, encoded["attention_mask"])
        # Unit-normalize so downstream dot products equal cosine similarity.
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
    vector = pooled[0].cpu().tolist()
    return tuple(float(x) for x in vector)
def _cosine_from_tuples(left: tuple[float, ...], right: tuple[float, ...]) -> float:
if not left or not right or len(left) != len(right):
return 0.0
return float(sum(a * b for a, b in zip(left, right)))
def _label_predicates_values() -> str:
    """Render the common label predicates as a SPARQL VALUES-ready IRI list."""
    return " ".join("<" + predicate + ">" for predicate in _COMMON_LABEL_PREDICATES)
def _description_predicates_values() -> str:
    """Render the common description predicates as a SPARQL VALUES-ready IRI list."""
    return " ".join("<" + predicate + ">" for predicate in _COMMON_DESCRIPTION_PREDICATES)
def _build_entity_lookup_query(surface_form: str, top_k: int) -> str:
    """Build an exact-label entity lookup query.

    Matches resources whose label under any common label predicate equals the
    surface form tagged ``@en``, optionally pulling an English description.
    NOTE(review): projecting ``SAMPLE(?label) AS ?label`` re-uses an in-scope
    variable and relies on implicit grouping — Virtuoso tolerates this, but
    strict SPARQL 1.1 engines may reject it; confirm against the target endpoint.
    """
    # Strip double quotes so the inlined literal stays well-formed; cap length.
    safe_phrase = surface_form.replace('"', "").strip()[:160]
    return f"""
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT DISTINCT ?uri (SAMPLE(?label) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
        VALUES ?lp {{ {_label_predicates_values()} }}
        VALUES ?exactLabel {{ "{safe_phrase}"@en }}
        ?uri ?lp ?exactLabel .
        BIND(?exactLabel AS ?label)
        OPTIONAL {{
            VALUES ?dp {{ {_description_predicates_values()} }}
            ?uri ?dp ?comment .
            FILTER (LANG(?comment) = 'en')
        }}
    }}
    LIMIT {max(top_k * 6, 20)}
    """
def _build_entity_fulltext_query(surface_form: str, top_k: int) -> str | None:
    """Build a faster full-text query for approximate entity retrieval.

    Uses ``bif:contains`` — a Virtuoso-specific full-text predicate — over
    English labels, OR-ing prefix terms built from the longer content words of
    the surface form. Returns ``None`` when no usable token remains.
    """
    # Keep only informative tokens (no stop words, length >= 4).
    tokens = [t for t in _tokenize(surface_form, remove_stop_words=True) if len(t) >= 4]
    # Tokens are already alphanumeric; quote-stripping is belt-and-braces.
    safe_tokens = [t.replace("'", "")[:40] for t in tokens[:4]]
    if not safe_tokens:
        return None
    fulltext = " OR ".join(f"'{tok}*'" for tok in safe_tokens)
    return f"""
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT DISTINCT ?uri (SAMPLE(?label) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
        VALUES ?lp {{ {_label_predicates_values()} }}
        ?uri ?lp ?label .
        FILTER (LANG(?label) = 'en')
        ?label bif:contains "{fulltext}" .
        OPTIONAL {{
            VALUES ?dp {{ {_description_predicates_values()} }}
            ?uri ?dp ?comment .
            FILTER (LANG(?comment) = 'en')
        }}
    }}
    LIMIT {max(top_k * 8, 30)}
    """
def _build_entity_redirect_query(surface_form: str, top_k: int) -> str:
    """Build a redirect-aware entity lookup query for DBpedia.

    Finds pages whose label matches the surface form exactly and follows
    ``dbo:wikiPageRedirects`` to the canonical resource, preferring the
    canonical English label when present.
    NOTE(review): the BIND re-assigns ?labelResolved, which is already in
    scope from the OPTIONAL — strict SPARQL engines may reject that; verify
    against the target endpoint (Virtuoso accepts it).
    """
    # Strip double quotes so the inlined literal stays well-formed; cap length.
    safe_phrase = surface_form.replace('"', "").strip()[:160]
    return f"""
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT DISTINCT ?uri (SAMPLE(?labelResolved) AS ?label) (SAMPLE(?comment) AS ?comment) WHERE {{
        VALUES ?lp {{ {_label_predicates_values()} }}
        VALUES ?exactLabel {{ "{safe_phrase}"@en }}
        ?redirect ?lp ?exactLabel .
        ?redirect <{_DBPEDIA_REDIRECT_PREDICATE}> ?uri .
        OPTIONAL {{
            VALUES ?lp2 {{ {_label_predicates_values()} }}
            ?uri ?lp2 ?labelResolved .
            FILTER (LANG(?labelResolved) = 'en')
        }}
        BIND(COALESCE(?labelResolved, ?exactLabel) AS ?labelResolved)
        OPTIONAL {{
            VALUES ?dp {{ {_description_predicates_values()} }}
            ?uri ?dp ?comment .
            FILTER (LANG(?comment) = 'en')
        }}
    }}
    LIMIT {max(top_k * 4, 15)}
    """
def _infer_entity_type_preference(question: str) -> str | None:
"""Infer a coarse answer/entity preference from the question wording."""
q = question.strip().lower()
if q.startswith("who"):
return "person"
if q.startswith("where"):
return "place"
if q.startswith("when"):
return "date"
if "airline" in q or "company" in q or "band" in q:
return "organization"
return None
def _type_aware_bonus(question: str, candidate: dict[str, Any]) -> float:
    """Give a small bonus when candidate text matches question type expectations."""
    preference = _infer_entity_type_preference(question)
    if preference is None:
        return 0.0
    # Search label, comment, and URI tail for type-hint keywords.
    parts = [
        candidate.get("label", ""),
        candidate.get("comment", ""),
        _split_uri_tail(candidate.get("uri", "")),
    ]
    haystack = " ".join(parts).lower()
    for keyword in _TYPE_HINT_KEYWORDS.get(preference, []):
        if keyword in haystack:
            return 0.8
    return 0.0
def _expand_schema_focus(question_focus: str, linker_mode: str) -> str:
    """Expand relation/class focus text for stronger lexical matching."""
    if linker_mode != "internal_max":
        return question_focus
    synonyms: list[str] = []
    for word in _tokenize(question_focus, remove_stop_words=True):
        synonyms.extend(_RELATION_SYNONYM_MAP.get(word, []))
    if not synonyms:
        return question_focus
    # Append synonyms so overlap scoring can hit alternate property wordings.
    return f"{question_focus} {' '.join(synonyms)}".strip()
def _score_entity_candidate(
    question: str,
    surface_form: str,
    candidate: dict[str, Any],
    linker_mode: str,
) -> float:
    """Rerank an entity candidate using label, URI-tail, and comment overlap.

    Blends exact/substring label matches, token-overlap scores against the
    mention and the whole question, URI-tail similarity, and comment overlap,
    then applies small URI-shape bonuses/penalties.
    """
    q_tokens = _tokenize(question, remove_stop_words=True)
    surface_tokens = _tokenize(surface_form, remove_stop_words=True)
    label = candidate.get("label", "")
    comment = candidate.get("comment", "")
    uri = candidate.get("uri", "")
    # Binary signals: normalized-exact match and substring containment.
    exact = 1.0 if _normalize_text(label) == _normalize_text(surface_form) else 0.0
    contains = 1.0 if _normalize_text(surface_form) in _normalize_text(label) else 0.0
    # Overlap signals; falls back to question tokens if the mention has none.
    surface_overlap = _fuzzy_match_score(surface_tokens or q_tokens, label)
    question_overlap = _fuzzy_match_score(q_tokens, label)
    tail_similarity = _sequence_similarity(surface_form, _split_uri_tail(uri))
    comment_overlap = _fuzzy_match_score(q_tokens, comment)
    score = (
        3.0 * exact
        + 1.75 * contains
        + 1.5 * surface_overlap
        + 0.9 * question_overlap
        + 0.9 * tail_similarity
        + 0.8 * comment_overlap
    )
    # Mild preference for canonical DBpedia resources.
    if "http://dbpedia.org/resource/" in uri:
        score += 0.1
    if linker_mode == "internal_max":
        score += _type_aware_bonus(question, candidate)
    # Penalize tails that often indicate works or disambiguation pages.
    # NOTE(review): source indentation was ambiguous here — this penalty is
    # applied in all linker modes; confirm it wasn't meant to be internal_max-only.
    tail = _split_uri_tail(uri)
    if any(flag in tail for flag in ("film", "album", "disambiguation")):
        score -= 0.3
    return round(score, 4)
def retrieve_entities_hybrid(
    question: str,
    endpoint_url: str,
    top_k: int,
    linker_mode: str = "internal_min",
) -> list[dict]:
    """Retrieve and rerank entity candidates from a live endpoint.

    Runs exact-label, full-text and (in ``internal_max`` mode) redirect-aware
    SPARQL lookups for each detected surface form, merges rows by URI, and
    reranks lexically. Returns at most ``top_k`` dicts with
    ``uri``/``label``/``comment``/``matched_mentions``/``score`` keys.
    """
    mentions = _extract_entity_mentions(question)
    surface_forms = _expand_entity_surface_forms(mentions)
    if not surface_forms:
        # No capitalized/quoted spans found: fall back to the first content words.
        fallback = " ".join(_tokenize(question, remove_stop_words=True)[:5])
        surface_forms = [fallback] if fallback else []
    merged: dict[str, dict[str, Any]] = {}
    for mention in surface_forms:
        queries = [_build_entity_lookup_query(mention, top_k)]
        fuzzy_query = _build_entity_fulltext_query(mention, top_k)
        if fuzzy_query is not None:
            queries.append(fuzzy_query)
        if linker_mode == "internal_max":
            queries.append(_build_entity_redirect_query(mention, top_k))
        for idx, query in enumerate(queries):
            # The cheap exact lookup gets a tighter timeout than the fallbacks.
            raw = _run_sparql_json(query, endpoint_url, timeout_sec=4 if idx == 0 else 5)
            if not raw:
                continue
            for row in raw.get("results", {}).get("bindings", []):
                uri = row.get("uri", {}).get("value", "")
                label = row.get("label", {}).get("value", "")
                comment = row.get("comment", {}).get("value", "")
                if not uri or not label:
                    continue
                # Merge rows across mentions/queries by URI, keeping the best score.
                candidate = merged.get(uri, {
                    "uri": uri,
                    "label": label,
                    "comment": comment,
                    "matched_mentions": [],
                    "score": 0.0,
                })
                if mention not in candidate["matched_mentions"]:
                    candidate["matched_mentions"].append(mention)
                if comment and not candidate.get("comment"):
                    candidate["comment"] = comment
                score = _score_entity_candidate(question, mention, candidate, linker_mode)
                candidate["score"] = max(candidate["score"], score)
                merged[uri] = candidate
    # Rank by score; ties prefer shorter (more canonical) labels.
    ranked = sorted(
        merged.values(),
        key=lambda x: (-x["score"], len(_normalize_text(x.get("label", "")))),
    )
    return ranked[:top_k]
def retrieve_entities_dbpedia_lookup(
    question: str,
    top_k: int,
    lookup_url: str,
) -> list[dict]:
    """Retrieve DBpedia entities using the official Lookup service.

    Queries the Lookup HTTP API once per surface form (max 6), merges results
    by URI, and combines the service's relevance score with local lexical
    scoring. Returns at most ``top_k`` candidate dicts.
    """
    mentions = _expand_entity_surface_forms(_extract_entity_mentions(question))
    if not mentions:
        # Fall back to the first content words when no mention was detected.
        fallback = " ".join(_tokenize(question, remove_stop_words=True)[:5])
        mentions = [fallback] if fallback else []
    merged: dict[str, dict[str, Any]] = {}
    for mention in mentions[:6]:
        params = urllib.parse.urlencode(
            {
                "format": "json",
                "query": mention,
                "maxResults": str(max(top_k * 5, 10)),
            }
        )
        payload = _http_json(
            f"{lookup_url}?{params}",
            headers={"Accept": "application/json"},
            timeout_sec=12,
        )
        if not payload:
            continue
        for row in payload.get("docs", []):
            # Lookup returns each field as a (possibly empty) list of strings.
            uri = next(iter(row.get("resource", [])), "") or next(iter(row.get("id", [])), "")
            label = _strip_highlight_markup(next(iter(row.get("label", [])), ""))
            comment = _strip_highlight_markup(next(iter(row.get("comment", [])), ""))
            score_raw = next(iter(row.get("score", [])), "0")
            if not uri:
                continue
            candidate = merged.get(
                uri,
                {
                    "uri": uri,
                    "label": label or _split_uri_tail(uri),
                    "comment": comment,
                    "score": 0.0,
                },
            )
            lexical = _score_entity_candidate(question, mention, candidate, "internal_max")
            try:
                backend_score = float(score_raw)
            except Exception:
                backend_score = 0.0
            # Blend in the (large-magnitude) backend score, capped at +2.0.
            candidate["score"] = max(candidate["score"], lexical + min(backend_score / 10000.0, 2.0))
            merged[uri] = candidate
    ranked = sorted(merged.values(), key=lambda item: (-item["score"], item.get("label", "")))
    return ranked[:top_k]
def retrieve_entities_dbpedia_spotlight(
    question: str,
    top_k: int,
    spotlight_base_url: str,
    language: str | None,
) -> list[dict]:
    """Retrieve DBpedia entities using the official Spotlight API.

    POSTs the question to the language-specific ``/annotate`` endpoint and
    merges annotations by URI, blending Spotlight's similarity and support
    signals with local lexical scoring.
    """
    lang = (language or "en").lower()
    annotate_url = f"{spotlight_base_url.rstrip('/')}/{lang}/annotate"
    payload = _http_json(
        annotate_url,
        headers={"Accept": "application/json"},
        data=urllib.parse.urlencode(
            {"text": question, "confidence": "0.35"}
        ).encode("utf-8"),
        timeout_sec=15,
    )
    if not payload:
        return []
    resources = payload.get("Resources", [])
    # Spotlight may return a bare object instead of a list for one annotation.
    if isinstance(resources, dict):
        resources = [resources]
    merged: dict[str, dict[str, Any]] = {}
    for row in resources:
        uri = row.get("@URI", "")
        surface_form = row.get("@surfaceForm", "")
        if not uri:
            continue
        candidate = merged.get(
            uri,
            {
                "uri": uri,
                "label": surface_form or _split_uri_tail(uri),
                "comment": "",
                "score": 0.0,
            },
        )
        lexical = _score_entity_candidate(question, surface_form or _split_uri_tail(uri), candidate, "internal_max")
        try:
            similarity = float(row.get("@similarityScore", "0") or 0.0)
        except Exception:
            similarity = 0.0
        try:
            support = float(row.get("@support", "0") or 0.0)
        except Exception:
            support = 0.0
        # Support (presumably page in-link count) is capped at a +1.0 contribution;
        # similarity is weighted 1.5x — confirm score ranges against Spotlight docs.
        candidate["score"] = max(candidate["score"], lexical + 1.5 * similarity + min(support / 10000.0, 1.0))
        merged[uri] = candidate
    ranked = sorted(merged.values(), key=lambda item: (-item["score"], item.get("label", "")))
    return ranked[:top_k]
def retrieve_entities_embedding_rerank(
    question: str,
    endpoint_url: str,
    top_k: int,
    model_name: str,
    candidate_pool_size: int,
) -> list[dict]:
    """Hybrid entity linker: lexical retrieval first, embeddings for reranking.

    Pulls a pool of lexical candidates from the endpoint, embeds the question
    plus each expanded mention, and rescores each candidate as a weighted
    blend of its lexical score and its best similarity to any query view.
    """
    pool_size = max(candidate_pool_size, top_k, 8)
    lexical_candidates = retrieve_entities_hybrid(
        question,
        endpoint_url,
        pool_size,
        linker_mode="internal_max",
    )
    if not lexical_candidates:
        return []
    # Query views: the whole question plus each expanded entity mention.
    mentions = _expand_entity_surface_forms(_extract_entity_mentions(question))
    query_views = [question] + mentions
    query_vectors = [_embed_single_text(text, model_name) for text in query_views if text.strip()]
    reranked: list[dict[str, Any]] = []
    for candidate in lexical_candidates:
        # Candidate text combines label, comment, and the readable URI tail.
        candidate_text = " ".join(
            [
                candidate.get("label", ""),
                candidate.get("comment", ""),
                _split_uri_tail(candidate.get("uri", "")),
            ]
        ).strip()
        if not candidate_text:
            candidate_text = candidate.get("uri", "")
        candidate_vector = _embed_single_text(candidate_text, model_name)
        # Max over views rewards both full-question and mention-level matches.
        semantic_similarity = max(
            (_cosine_from_tuples(query_vector, candidate_vector) for query_vector in query_vectors),
            default=0.0,
        )
        candidate_copy = dict(candidate)
        candidate_copy["lexical_score"] = candidate.get("score", 0.0)
        candidate_copy["semantic_similarity"] = round(semantic_similarity, 4)
        candidate_copy["score"] = round(
            0.45 * candidate.get("score", 0.0) + 4.0 * semantic_similarity,
            4,
        )
        reranked.append(candidate_copy)
    reranked.sort(key=lambda item: (-item["score"], item.get("label", "")))
    return reranked[:top_k]
def retrieve_entity_candidates(
    question: str,
    kg_profile: dict,
    top_k: int,
    endpoint_url: str | None = None,
    linker_mode: str = "internal_min",
    lookup_url: str | None = None,
    spotlight_base_url: str | None = None,
    language: str | None = None,
    embedding_model_name: str | None = None,
    embedding_candidate_pool_size: int = 16,
) -> list[dict]:
    """Retrieve top entity candidates from KG profile or live endpoint.

    Dispatches on ``linker_mode``. Each external strategy is best-effort and
    falls through when it yields nothing: embedding rerank / DBpedia Lookup /
    Spotlight -> live endpoint retrieval -> static KG-profile matching.
    """
    if linker_mode == "embedding_rerank" and endpoint_url and embedding_model_name:
        external = retrieve_entities_embedding_rerank(
            question,
            endpoint_url,
            top_k,
            embedding_model_name,
            embedding_candidate_pool_size,
        )
        if external:
            return external
    if linker_mode == "dbpedia_lookup" and lookup_url:
        external = retrieve_entities_dbpedia_lookup(question, top_k, lookup_url)
        if external:
            return external
    if linker_mode == "dbpedia_spotlight" and spotlight_base_url:
        external = retrieve_entities_dbpedia_spotlight(
            question,
            top_k,
            spotlight_base_url,
            language,
        )
        if external:
            return external
    # Generic live-endpoint retrieval (also the fallback for external modes).
    if endpoint_url:
        live_candidates = retrieve_entities_hybrid(
            question,
            endpoint_url,
            top_k,
            linker_mode=linker_mode,
        )
        if live_candidates:
            return live_candidates
    # Last resort: lexical matching against the static profile's entity list.
    entities = kg_profile.get("entities", [])
    if not entities:
        return []
    q_tokens = _tokenize(question, remove_stop_words=True)
    scored = []
    for entity in entities:
        label = entity.get("label", "")
        uri = entity.get("uri", "")
        score = (
            1.2 * _fuzzy_match_score(q_tokens, label)
            + 0.4 * _sequence_similarity(question, label)
        )
        if score > 0:
            scored.append({
                "uri": uri,
                "label": label,
                "score": round(score, 4),
            })
    scored.sort(key=lambda x: x["score"], reverse=True)
    return scored[:top_k]
def _score_schema_candidate(
    question_focus: str,
    label: str,
    uri: str,
    *,
    prefer_ontology: bool = False,
) -> float:
    """Score a property/class candidate using multiple lexical signals."""
    focus_tokens = _tokenize(question_focus, remove_stop_words=True)
    label_norm = _normalize_text(label)
    token_overlap = _fuzzy_match_score(focus_tokens, label_norm)
    char_similarity = _sequence_similarity(question_focus, label_norm)
    # Binary signal: any focus token (>2 chars) appearing inside the label.
    has_substring = any(tok in label_norm for tok in focus_tokens if len(tok) > 2)
    tail_similarity = _sequence_similarity(question_focus, _split_uri_tail(uri))
    score = (
        1.8 * token_overlap
        + 0.8 * char_similarity
        + 0.7 * (1.0 if has_substring else 0.0)
        + 0.5 * tail_similarity
    )
    # DBpedia's curated /ontology/ namespace is usually cleaner than /property/.
    if prefer_ontology and "/ontology/" in uri:
        score += 0.2
    return round(score, 4)
def retrieve_relation_candidates(
    question: str, kg_profile: dict, top_k: int, linker_mode: str = "internal_min"
) -> list[dict]:
    """Retrieve top relation/property candidates from KG profile."""
    properties = kg_profile.get("properties", [])
    if not properties:
        return []
    # Score against the question with entity spans removed (relation-focused text).
    focus = _expand_schema_focus(
        _build_relation_focus(question, _extract_entity_mentions(question)),
        linker_mode,
    )
    candidates = []
    for entry in properties:
        uri = entry.get("uri", "")
        label = entry.get("label", "")
        score = _score_schema_candidate(focus, label, uri, prefer_ontology=True)
        if score > 0:
            candidates.append({"uri": uri, "label": label, "score": score})
    candidates.sort(key=lambda item: item["score"], reverse=True)
    return candidates[:top_k]
def retrieve_class_candidates(
    question: str, kg_profile: dict, top_k: int, linker_mode: str = "internal_min"
) -> list[dict]:
    """Retrieve top class/type candidates from KG profile."""
    classes = kg_profile.get("classes", [])
    if not classes:
        return []
    # Score against the question with entity spans removed (schema-focused text).
    focus = _expand_schema_focus(
        _build_relation_focus(question, _extract_entity_mentions(question)),
        linker_mode,
    )
    candidates = []
    for entry in classes:
        uri = entry.get("uri", "")
        label = entry.get("label", "")
        score = _score_schema_candidate(focus, label, uri)
        if score > 0:
            candidates.append({"uri": uri, "label": label, "score": score})
    candidates.sort(key=lambda item: item["score"], reverse=True)
    return candidates[:top_k]
def build_context(
    request: QueryRequest,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
    kg_profile: dict,
) -> ContextPackage:
    """Build the compact context package for generation and expert prompts.

    Combines answer-type inference, mention detection, and entity/relation/
    class candidate retrieval into a single ``ContextPackage``. External
    linker services and the live endpoint are only used in "dbpedia" mode.
    """
    question = request.question
    answer_type = infer_answer_type(question)
    entity_mentions = _extract_entity_mentions(question)
    relation_candidates = retrieve_relation_candidates(
        question, kg_profile, runtime.relation_top_k, linker_mode=runtime.linker_mode
    )
    class_candidates = retrieve_class_candidates(
        question, kg_profile, runtime.class_top_k, linker_mode=runtime.linker_mode
    )
    entity_candidates = retrieve_entity_candidates(
        question,
        kg_profile,
        runtime.entity_top_k,
        # Live retrieval and external linkers only apply to DBpedia-mode datasets.
        endpoint_url=dataset.endpoint_url if dataset.mode == "dbpedia" else None,
        linker_mode=runtime.linker_mode,
        lookup_url=runtime.dbpedia_lookup_url if dataset.mode == "dbpedia" else None,
        spotlight_base_url=runtime.dbpedia_spotlight_base_url if dataset.mode == "dbpedia" else None,
        language=request.language,
        embedding_model_name=runtime.embedding_model_name,
        embedding_candidate_pool_size=runtime.embedding_candidate_pool_size,
    )
    # Notes surface retrieval gaps to downstream prompts for better grounding.
    notes: list[str] = []
    if entity_mentions:
        notes.append(f"Detected entity mentions: {', '.join(entity_mentions)}")
    if not entity_candidates:
        notes.append("No entity candidates found from endpoint/profile.")
    if not relation_candidates:
        notes.append("No relation candidates found in KG profile.")
    if not class_candidates:
        notes.append("No class candidates found in KG profile.")
    context = ContextPackage(
        entity_candidates=entity_candidates,
        relation_candidates=relation_candidates,
        class_candidates=class_candidates,
        answer_type_hint=answer_type,
        prefix_hints=dataset.default_prefixes,
        notes=notes,
    )
    logger.info(
        "Context built: %d entities, %d relations, %d classes, type=%s",
        len(entity_candidates),
        len(relation_candidates),
        len(class_candidates),
        answer_type,
    )
    return context