Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Initial SPARQL generation for the Text2SPARQL repair pipeline. | |
| The current pipeline is intentionally single-candidate: | |
| - generate one SPARQL draft | |
| - repair that draft iteratively | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| from typing import Any | |
| from .config import RuntimeConfig | |
| from .llm import LLMClient | |
| from .models import CandidateQuery, ContextPackage, QueryRequest | |
| from .prompts import build_generation_prompt | |
| from .utils import short_hash | |
| logger = logging.getLogger(__name__) | |
def parse_generation_output(raw_text: str) -> list[str]:
    """Parse SPARQL queries from LLM generation output.

    Extracts queries from ```sparql ... ``` code blocks; the language tag is
    matched case-insensitively, since LLMs also emit ```SPARQL / ```Sparql.
    Falls back to generic ``` ... ``` blocks that look like SPARQL, stripping
    a stray language-tag line (e.g. ```txt) when present so it does not leak
    into the query text.

    Args:
        raw_text: Raw LLM output text.

    Returns:
        List of SPARQL query strings (possibly empty).
    """
    sparql_keywords = ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE", "PREFIX")

    # Preferred path: explicitly tagged ```sparql fenced blocks.
    pattern = r"```sparql\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return [m.strip() for m in matches if m.strip()]

    # Fallback: untagged (or differently tagged) fenced blocks.
    pattern = r"```\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL)
    queries: list[str] = []
    for m in matches:
        m = m.strip()
        # The generic fence pattern captures any language tag together with
        # the content; drop a first line that is a lone non-keyword word.
        first_line, sep, rest = m.partition("\n")
        if (
            sep
            and re.fullmatch(r"[A-Za-z][\w+-]*", first_line)
            and first_line.upper() not in sparql_keywords
        ):
            m = rest.strip()
        # Only include content that plausibly is SPARQL.
        if any(kw in m.upper() for kw in sparql_keywords):
            queries.append(m)
    return queries
def normalize_query(query: str) -> str:
    """Normalize a SPARQL query into a whitespace-insensitive comparison key.

    Collapses all runs of whitespace (including newlines) into single spaces
    and strips leading/trailing whitespace.  The result is used only for
    deduplication; callers keep the original formatting.  Note: the query is
    deliberately NOT lowercased — SPARQL variable names are case-sensitive,
    so lowercasing would merge distinct queries.

    Args:
        query: SPARQL query string.

    Returns:
        Normalized query string.
    """
    # Collapse all whitespace to single spaces so formatting-only variants
    # of the same query compare equal.
    return re.sub(r"\s+", " ", query.strip())
def deduplicate_queries(queries: list[str]) -> list[str]:
    """Remove duplicate queries after normalization.

    Two queries count as duplicates when they are identical after collapsing
    whitespace (the same normalization as ``normalize_query``); the first
    occurrence, with its original formatting, wins.

    Args:
        queries: List of SPARQL query strings.

    Returns:
        List of unique SPARQL query strings (preserving original formatting).
    """
    # Insertion-ordered dict keyed by the whitespace-collapsed form;
    # setdefault keeps only the first formatting seen for each key.
    by_key: dict[str, str] = {}
    for candidate in queries:
        key = re.sub(r"\s+", " ", candidate.strip())
        by_key.setdefault(key, candidate)
    return list(by_key.values())
def generate_initial_candidates(
    request: QueryRequest,
    context: ContextPackage,
    runtime: RuntimeConfig,
    llm: LLMClient,
) -> list[CandidateQuery]:
    """Generate exactly one initial SPARQL candidate.

    Args:
        request: The query request.
        context: Context package with grounding hints.
        runtime: Runtime configuration.
        llm: LLM client for generation.

    Returns:
        A single-element list of CandidateQuery objects, or empty on failure.
    """
    # The main pipeline is single-path by design: ask for exactly one draft.
    num_candidates = 1
    prompt = build_generation_prompt(
        request.question,
        context,
        num_candidates,
        prompt_files=runtime.prompt_files,
    )

    parsed = parse_generation_output(llm.generate_text(prompt))
    logger.info("Generation produced %d raw queries", len(parsed))

    deduped = deduplicate_queries(parsed)
    logger.info("After dedup: %d unique queries", len(deduped))

    # Keep at most one query; warn when generation yielded nothing usable.
    selected = deduped[:1]
    if not selected:
        logger.warning("Could not generate a valid initial candidate.")

    candidates = [
        CandidateQuery(
            candidate_id=f"gen_{idx}_{short_hash(sparql)}",
            query=sparql,
            source="generation",
            generation_index=idx,
            parent_candidate_id=None,
            repair_iteration=0,
        )
        for idx, sparql in enumerate(selected)
    ]
    logger.info("Generated %d initial candidate(s)", len(candidates))
    return candidates