File size: 3,973 Bytes
d745844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Initial SPARQL generation for the Text2SPARQL repair pipeline.

The current pipeline is intentionally single-candidate:
- generate one SPARQL draft
- repair that draft iteratively
"""

from __future__ import annotations

import logging
import re
from typing import Any

from .config import RuntimeConfig
from .llm import LLMClient
from .models import CandidateQuery, ContextPackage, QueryRequest
from .prompts import build_generation_prompt
from .utils import short_hash

logger = logging.getLogger(__name__)


def parse_generation_output(raw_text: str) -> list[str]:
    """Parse SPARQL queries from LLM generation output.

    Extracts queries from ```sparql ... ``` code blocks; the fence tag is
    matched case-insensitively because LLMs frequently emit ```SPARQL or
    ```Sparql. Falls back to untagged ``` ... ``` blocks whose content
    looks like SPARQL.

    Args:
        raw_text: Raw LLM output text.

    Returns:
        List of non-empty SPARQL query strings (possibly empty).
    """
    # Preferred form: explicitly tagged ```sparql fences.
    # IGNORECASE handles ```SPARQL / ```Sparql variants the original
    # case-sensitive pattern silently dropped to the fallback path.
    pattern = r"```sparql\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL | re.IGNORECASE)

    if matches:
        return [m.strip() for m in matches if m.strip()]

    # Fallback: untagged ``` ... ``` blocks. Keep only content that
    # contains a SPARQL query-form or PREFIX keyword, so prose/JSON
    # fences are filtered out.
    pattern = r"```\s*\n?(.*?)\n?\s*```"
    matches = re.findall(pattern, raw_text, re.DOTALL)

    queries = []
    for m in matches:
        m = m.strip()
        # Only include if it looks like SPARQL
        if any(kw in m.upper() for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE", "PREFIX")):
            queries.append(m)

    return queries


def normalize_query(query: str) -> str:
    """Normalize a SPARQL query for use as a deduplication key.

    - Collapses all whitespace runs to single spaces
    - Strips trailing semicolons/periods (and surrounding spaces)

    Keywords are deliberately NOT lowercased: string literals and IRIs
    are case-sensitive, so case-folding could merge genuinely different
    queries. Callers keep the original formatting; this value is only a
    dedup key.

    Args:
        query: SPARQL query string.

    Returns:
        Normalized query string.
    """
    # Collapse all whitespace (newlines, tabs, runs of spaces) to single spaces.
    normalized = re.sub(r"\s+", " ", query.strip())
    # Drop trailing statement punctuation so "... } ;" and "... }" map to
    # the same dedup key, as the original docstring promised.
    return normalized.rstrip("; .")


def deduplicate_queries(queries: list[str]) -> list[str]:
    """Remove duplicate queries after normalization.

    Two queries are duplicates when their normalized forms (see
    normalize_query) compare equal. The first occurrence wins and its
    original formatting is preserved.

    Args:
        queries: List of SPARQL query strings.

    Returns:
        List of unique SPARQL query strings (preserving original formatting).
    """
    # dicts preserve insertion order; setdefault keeps the first query
    # seen for each normalized key and ignores later duplicates.
    first_by_key: dict[str, str] = {}
    for candidate in queries:
        first_by_key.setdefault(normalize_query(candidate), candidate)
    return list(first_by_key.values())


def generate_initial_candidates(
    request: QueryRequest,
    context: ContextPackage,
    runtime: RuntimeConfig,
    llm: LLMClient,
) -> list[CandidateQuery]:
    """Generate exactly one initial SPARQL candidate.

    Args:
        request: The query request.
        context: Context package with grounding hints.
        runtime: Runtime configuration.
        llm: LLM client for generation.

    Returns:
        A single-element list of CandidateQuery objects, or empty on failure.
    """
    # The pipeline is single-path by design: one draft, repaired iteratively.
    candidate_count = 1

    prompt = build_generation_prompt(
        request.question,
        context,
        candidate_count,
        prompt_files=runtime.prompt_files,
    )
    parsed = parse_generation_output(llm.generate_text(prompt))
    logger.info("Generation produced %d raw queries", len(parsed))

    deduped = deduplicate_queries(parsed)
    logger.info("After dedup: %d unique queries", len(deduped))

    # Keep only the first unique query (may be empty when parsing failed).
    selected = deduped[:1]
    if not selected:
        logger.warning("Could not generate a valid initial candidate.")

    candidates = [
        CandidateQuery(
            candidate_id=f"gen_{idx}_{short_hash(sparql)}",
            query=sparql,
            source="generation",
            generation_index=idx,
            parent_candidate_id=None,
            repair_iteration=0,
        )
        for idx, sparql in enumerate(selected)
    ]

    logger.info("Generated %d initial candidate(s)", len(candidates))
    return candidates