"""Initial SPARQL generation for the Text2SPARQL repair pipeline.
The current pipeline is intentionally single-candidate:
- generate one SPARQL draft
- repair that draft iteratively
"""
from __future__ import annotations
import logging
import re
from typing import Any
from .config import RuntimeConfig
from .llm import LLMClient
from .models import CandidateQuery, ContextPackage, QueryRequest
from .prompts import build_generation_prompt
from .utils import short_hash
logger = logging.getLogger(__name__)
def parse_generation_output(raw_text: str) -> list[str]:
    """Parse SPARQL queries from LLM generation output.

    Extracts queries from ```sparql ... ``` code blocks. The language tag is
    matched case-insensitively because LLMs frequently emit ```SPARQL or
    ```Sparql fences; without that, tagged blocks fell through to the untagged
    fallback and the tag text leaked into the extracted query. Untagged
    ``` ... ``` blocks are used as a fallback, kept only if they contain a
    SPARQL keyword.

    Args:
        raw_text: Raw LLM output text.

    Returns:
        List of SPARQL query strings (stripped; empty blocks dropped).
    """
    # Preferred path: explicitly tagged ```sparql fences (any tag casing).
    pattern = r"```sparql\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL | re.IGNORECASE)
    if matches:
        return [m.strip() for m in matches if m.strip()]
    # Fallback: untagged ``` ... ``` blocks.
    pattern = r"```\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL)
    queries: list[str] = []
    for m in matches:
        m = m.strip()
        # Heuristic keyword filter to skip non-SPARQL code blocks.
        if any(kw in m.upper() for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE", "PREFIX")):
            queries.append(m)
    return queries
def normalize_query(query: str) -> str:
    """Normalize a SPARQL query for deduplication.

    Collapses every run of whitespace (including newlines and tabs) into a
    single space and trims leading/trailing whitespace. The query text is
    otherwise unchanged.

    Args:
        query: SPARQL query string.

    Returns:
        Normalized query string.
    """
    # NOTE: normalization is deliberately whitespace-only. Lowercasing or
    # punctuation stripping (which an earlier docstring incorrectly claimed)
    # could merge queries that are not textually interchangeable.
    return re.sub(r"\s+", " ", query.strip())
def deduplicate_queries(queries: list[str]) -> list[str]:
    """Remove duplicate queries after normalization.

    Two queries are considered duplicates when their normalized forms
    (see :func:`normalize_query`) are equal. The first occurrence wins and
    keeps its original formatting; insertion order is preserved.

    Args:
        queries: List of SPARQL query strings.

    Returns:
        List of unique SPARQL query strings (preserving original formatting).
    """
    # dict preserves insertion order; setdefault keeps the first original
    # text seen for each normalized key.
    first_seen: dict[str, str] = {}
    for candidate in queries:
        first_seen.setdefault(normalize_query(candidate), candidate)
    return list(first_seen.values())
def generate_initial_candidates(
    request: QueryRequest,
    context: ContextPackage,
    runtime: RuntimeConfig,
    llm: LLMClient,
) -> list[CandidateQuery]:
    """Generate exactly one initial SPARQL candidate.

    Args:
        request: The query request.
        context: Context package with grounding hints.
        runtime: Runtime configuration.
        llm: LLM client for generation.

    Returns:
        A single-element list of CandidateQuery objects, or empty on failure.
    """
    # The main pipeline is single-path by design: request exactly one draft.
    draft_count = 1
    prompt = build_generation_prompt(
        request.question,
        context,
        draft_count,
        prompt_files=runtime.prompt_files,
    )

    parsed = parse_generation_output(llm.generate_text(prompt))
    logger.info("Generation produced %d raw queries", len(parsed))

    deduped = deduplicate_queries(parsed)
    logger.info("After dedup: %d unique queries", len(deduped))

    # Keep at most one query; warn when the model gave us nothing usable.
    kept = deduped[:1]
    if not kept:
        logger.warning("Could not generate a valid initial candidate.")

    candidates = [
        CandidateQuery(
            candidate_id=f"gen_{idx}_{short_hash(text)}",
            query=text,
            source="generation",
            generation_index=idx,
            parent_candidate_id=None,
            repair_iteration=0,
        )
        for idx, text in enumerate(kept)
    ]
    logger.info("Generated %d initial candidate(s)", len(candidates))
    return candidates