# iris-at-text2sparql / src/generation.py
# Author: Alex Latipov
# Commit d745844 — Harden frozen eval prompts and judge JSON handling
"""Initial SPARQL generation for the Text2SPARQL repair pipeline.
The current pipeline is intentionally single-candidate:
- generate one SPARQL draft
- repair that draft iteratively
"""
from __future__ import annotations
import logging
import re
from typing import Any
from .config import RuntimeConfig
from .llm import LLMClient
from .models import CandidateQuery, ContextPackage, QueryRequest
from .prompts import build_generation_prompt
from .utils import short_hash
logger = logging.getLogger(__name__)
def parse_generation_output(raw_text: str) -> list[str]:
    """Parse SPARQL queries from LLM generation output.

    Extracts queries from ```sparql ... ``` code blocks; the fence tag is
    matched case-insensitively so ```SPARQL / ```Sparql also work. Falls
    back to untagged ``` ... ``` blocks that look like SPARQL.

    Args:
        raw_text: Raw LLM output text.

    Returns:
        List of SPARQL query strings (empty if nothing parseable is found).
    """
    # Preferred: explicitly tagged fences. IGNORECASE because models
    # frequently emit the tag in upper or mixed case.
    pattern = r"```sparql\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return [m.strip() for m in matches if m.strip()]
    # Fallback: untagged ``` ... ``` fences. Keep only blocks containing a
    # SPARQL keyword so prose/JSON snippets are not mistaken for queries.
    pattern = r"```\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL)
    queries: list[str] = []
    for m in matches:
        m = m.strip()
        if any(kw in m.upper() for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE", "PREFIX")):
            queries.append(m)
    return queries
def normalize_query(query: str) -> str:
    """Normalize a SPARQL query for whitespace-insensitive deduplication.

    Collapses every run of whitespace (including newlines and tabs) into a
    single space and trims the ends. Casing and punctuation are deliberately
    left untouched, so queries that differ in anything but layout are never
    merged by deduplication.

    Args:
        query: SPARQL query string.

    Returns:
        Normalized query string.
    """
    # Collapse all whitespace to single spaces
    return re.sub(r"\s+", " ", query.strip())
def deduplicate_queries(queries: list[str]) -> list[str]:
    """Drop queries that are duplicates of an earlier one after normalization.

    Args:
        queries: List of SPARQL query strings.

    Returns:
        Unique queries in first-seen order, keeping each survivor's
        original formatting.
    """
    # First-wins mapping: normalized form -> the first query that produced it.
    first_seen: dict[str, str] = {}
    for candidate in queries:
        first_seen.setdefault(normalize_query(candidate), candidate)
    return list(first_seen.values())
def generate_initial_candidates(
    request: QueryRequest,
    context: ContextPackage,
    runtime: RuntimeConfig,
    llm: LLMClient,
) -> list[CandidateQuery]:
    """Generate exactly one initial SPARQL candidate.

    Args:
        request: The query request.
        context: Context package with grounding hints.
        runtime: Runtime configuration.
        llm: LLM client for generation.

    Returns:
        A single-element list of CandidateQuery objects, or empty on failure.
    """
    # The main pipeline is single-path by design, hence k = 1.
    prompt = build_generation_prompt(
        request.question,
        context,
        1,
        prompt_files=runtime.prompt_files,
    )
    parsed = parse_generation_output(llm.generate_text(prompt))
    logger.info("Generation produced %d raw queries", len(parsed))
    deduped = deduplicate_queries(parsed)
    logger.info("After dedup: %d unique queries", len(deduped))
    # Keep at most one draft; the repair loop works on a single candidate.
    selected = deduped[:1]
    if not selected:
        logger.warning("Could not generate a valid initial candidate.")
    candidates = [
        CandidateQuery(
            candidate_id=f"gen_{idx}_{short_hash(text)}",
            query=text,
            source="generation",
            generation_index=idx,
            parent_candidate_id=None,
            repair_iteration=0,
        )
        for idx, text in enumerate(selected)
    ]
    logger.info("Generated %d initial candidate(s)", len(candidates))
    return candidates