"""Initial SPARQL generation for the Text2SPARQL repair pipeline. The current pipeline is intentionally single-candidate: - generate one SPARQL draft - repair that draft iteratively """ from __future__ import annotations import logging import re from typing import Any from .config import RuntimeConfig from .llm import LLMClient from .models import CandidateQuery, ContextPackage, QueryRequest from .prompts import build_generation_prompt from .utils import short_hash logger = logging.getLogger(__name__) def parse_generation_output(raw_text: str) -> list[str]: """Parse SPARQL queries from LLM generation output. Extracts queries from ```sparql ... ``` code blocks. Args: raw_text: Raw LLM output text. Returns: List of SPARQL query strings. """ # Find all ```sparql ... ``` blocks pattern = r"```sparql\s*\n?(.*?)\n?\s*```" matches = re.findall(pattern, raw_text, re.DOTALL) if matches: return [m.strip() for m in matches if m.strip()] # Fallback: try ``` ... ``` blocks pattern = r"```\s*\n?(.*?)\n?\s*```" matches = re.findall(pattern, raw_text, re.DOTALL) queries = [] for m in matches: m = m.strip() # Only include if it looks like SPARQL if any(kw in m.upper() for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE", "PREFIX")): queries.append(m) return queries def normalize_query(query: str) -> str: """Normalize a SPARQL query for deduplication. - Collapses whitespace - Lowercases keywords - Strips trailing semicolons/periods Args: query: SPARQL query string. Returns: Normalized query string. """ # Collapse all whitespace to single spaces normalized = re.sub(r"\s+", " ", query.strip()) return normalized def deduplicate_queries(queries: list[str]) -> list[str]: """Remove duplicate queries after normalization. Args: queries: List of SPARQL query strings. Returns: List of unique SPARQL query strings (preserving original formatting). """ seen_normalized: set[str] = set() unique: list[str] = [] for q in queries: norm = normalize_query(q) if norm not in seen_normalized: seen_normalized.add(norm) unique.append(q) return unique def generate_initial_candidates( request: QueryRequest, context: ContextPackage, runtime: RuntimeConfig, llm: LLMClient, ) -> list[CandidateQuery]: """Generate exactly one initial SPARQL candidate. Args: request: The query request. context: Context package with grounding hints. runtime: Runtime configuration. llm: LLM client for generation. Returns: A single-element list of CandidateQuery objects, or empty on failure. """ # The main pipeline is single-path by design. k = 1 prompt = build_generation_prompt( request.question, context, k, prompt_files=runtime.prompt_files, ) raw_output = llm.generate_text(prompt) raw_queries = parse_generation_output(raw_output) logger.info("Generation produced %d raw queries", len(raw_queries)) unique_queries = deduplicate_queries(raw_queries) logger.info("After dedup: %d unique queries", len(unique_queries)) final_queries = unique_queries[:1] if len(final_queries) < 1: logger.warning("Could not generate a valid initial candidate.") candidates = [] for i, query in enumerate(final_queries): cid = f"gen_{i}_{short_hash(query)}" candidates.append( CandidateQuery( candidate_id=cid, query=query, source="generation", generation_index=i, parent_candidate_id=None, repair_iteration=0, ) ) logger.info("Generated %d initial candidate(s)", len(candidates)) return candidates