Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Cheap symbolic validation for the Text2SPARQL repair pipeline. | |
| All validation is symbolic — no LLM calls. Scores candidates for selection. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| from typing import Any | |
| from .config import RuntimeConfig | |
| from .models import ( | |
| CandidateQuery, | |
| ContextPackage, | |
| DatasetConfig, | |
| QueryRequest, | |
| ValidationResult, | |
| ) | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Result counts strictly greater than this value get the "huge_result" flag
# (see the comparison in validate_candidate).
_HUGE_RESULT_THRESHOLD = 10000
def parse_check(query: str) -> tuple[bool, str | None]:
    """Report whether rdflib's SPARQL parser accepts ``query``.

    When rdflib cannot be imported, a warning is logged and the verdict is
    delegated to ``_basic_parse_check`` instead.

    Args:
        query: SPARQL query string.

    Returns:
        Tuple of (parse_ok, error_message). ``error_message`` is ``None`` on
        success; when the rdflib parser raises, it is the stringified
        exception.
    """
    try:
        from rdflib.plugins.sparql.parser import parseQuery

        parseQuery(query)
    except ImportError:
        # If rdflib is not installed, do a basic structural check
        logger.warning("rdflib not installed — using basic parse check")
        return _basic_parse_check(query)
    except Exception as parse_error:
        # Any parser failure is reported to the caller, never raised.
        return False, str(parse_error)
    return True, None
def _basic_parse_check(query: str) -> tuple[bool, str | None]:
    """Basic structural SPARQL parse check without rdflib.

    Checks, in order, that a query-form keyword is present, that curly braces
    are balanced, and that a WHERE keyword is present (not required when the
    ASK keyword is present).

    Keywords are matched as whole words rather than raw substrings, so text
    such as ``?task`` or ``selected`` no longer counts as ``ASK``/``SELECT``.
    This is still a heuristic: keywords inside string literals are counted,
    and braces inside literals are not excluded.

    Args:
        query: SPARQL query string.

    Returns:
        Tuple of (parse_ok, error_message); ``error_message`` is ``None``
        when all checks pass.
    """
    # Collect keyword-shaped words only. The lookbehind rejects letters that
    # continue a variable (?task), a prefixed name (dbo:ask) or an IRI
    # segment (/ask, #ask, .ask); the trailing \b rejects longer identifiers.
    words = set(re.findall(r"(?<![\w?$:/#.-])[A-Z]+\b", query.upper()))

    if words.isdisjoint(("SELECT", "ASK", "CONSTRUCT", "DESCRIBE")):
        return False, "No SPARQL query keyword found (SELECT/ASK/CONSTRUCT/DESCRIBE)"

    # Check balanced braces
    open_count = query.count("{")
    close_count = query.count("}")
    if open_count != close_count:
        return False, f"Unbalanced braces: {open_count} open, {close_count} close"

    if "WHERE" not in words and "ASK" not in words:
        return False, "Missing WHERE clause"
    return True, None
def execute_query(
    query: str, endpoint_url: str, timeout_sec: int
) -> tuple[bool, list[dict], int | None, str | None, bool]:
    """Run a SPARQL query against an endpoint and summarise the outcome.

    The request is sent with SPARQLWrapper using POST and the JSON return
    format. Failures are reported through the return value, never raised.

    Args:
        query: SPARQL query string.
        endpoint_url: SPARQL endpoint URL.
        timeout_sec: Request timeout in seconds.

    Returns:
        Tuple of (execute_ok, results, result_count, error_message, timed_out).

        - Payload with a ``"boolean"`` key: one ``{"boolean": ...}`` row and
          a count of 1.
        - Payload with ``results.bindings``: ``result_count`` is the number
          of bindings, while ``results`` holds at most the first five rows,
          each flattened to ``{variable: value}``.
        - Any other payload: success with no rows and a count of 0.
        - SPARQLWrapper not importable: failure with a fixed message.
        - Any other exception: failure with the message cut to 500
          characters; ``timed_out`` is derived from timeout-like phrases in
          that message.
    """
    try:
        from SPARQLWrapper import SPARQLWrapper, JSON, POST

        client = SPARQLWrapper(endpoint_url)
        client.setQuery(query)
        client.setReturnFormat(JSON)
        client.setTimeout(timeout_sec)
        client.setMethod(POST)
        payload = client.query().convert()

        # ASK query
        if "boolean" in payload:
            return True, [{"boolean": payload["boolean"]}], 1, None, False

        # Neither boolean nor bindings: nothing to report, but not an error.
        if "results" not in payload or "bindings" not in payload["results"]:
            return True, [], 0, None, False

        rows = payload["results"]["bindings"]
        total = len(rows)
        # Keep only first few for preview
        preview = []
        for row in rows[:5]:
            preview.append(
                {name: cell.get("value", "") for name, cell in row.items()}
            )
        return True, preview, total, None, False
    except ImportError:
        logger.warning("SPARQLWrapper not installed — skipping endpoint execution")
        return False, [], None, "SPARQLWrapper not installed", False
    except Exception as exc:
        message = str(exc)
        lowered = message.lower()
        timed_out = False
        for phrase in ("timeout", "timed out", "time out", "deadline"):
            if phrase in lowered:
                timed_out = True
                break
        return False, [], None, message[:500], timed_out
def _detect_query_form(query: str) -> str:
    """Detect the SPARQL query form (ASK, SELECT, etc.).

    The prologue is skipped first: any mix of whitespace, ``#`` comment
    lines, ``BASE <iri>`` and ``PREFIX name: <iri>`` declarations, with or
    without whitespace between the prefix name and the IRI. The first keyword
    after the prologue decides the form.

    Precedence is ASK first, then any query containing a ``COUNT(`` call
    (whitespace or newlines before the parenthesis allowed), then the
    leading keyword.

    NOTE(review): COUNT is searched in the whole query, so a SELECT that only
    uses COUNT inside HAVING / ORDER BY is also reported as "count" — confirm
    this is the intended classification.

    Args:
        query: SPARQL query string.

    Returns:
        One of "ask", "count", "select", "construct", "describe", "unknown".
    """
    # The group is starred, so this always matches (possibly the empty string).
    prologue = re.match(
        r"(?:\s|#[^\n]*|BASE\s*<[^>]*>|PREFIX\s+[^\s:<]*:\s*<[^>]*>)*",
        query,
        flags=re.IGNORECASE,
    )
    # Slice before upper-casing: upper() can change the string length.
    body = query[prologue.end():].upper()

    leading = re.match(r"(ASK|SELECT|CONSTRUCT|DESCRIBE)\b", body)
    form = leading.group(1).lower() if leading else "unknown"

    if form == "ask":
        return "ask"
    if re.search(r"\bCOUNT\s*\(", body):
        return "count"
    return form
def score_answer_type_fit(
    question: str, query: str, answer_type_hint: str
) -> float:
    """Rate agreement between the detected query form and the expected type.

    Args:
        question: Natural language question (not consulted by the current
            implementation).
        query: SPARQL query.
        answer_type_hint: Expected type ("ask", "count", "select").

    Returns:
        Score between 0.0 and 1.0; 0.5 when the hint is not one of the three
        known values.
    """
    detected_form = _detect_query_form(query)

    # hint -> (score per detected form, score for every other form)
    fit_table = {
        "ask": ({"ask": 1.0}, 0.0),
        # Select could still work for a count question
        "count": ({"count": 1.0, "select": 0.3}, 0.0),
        "select": ({"select": 1.0, "count": 0.3}, 0.2),
    }

    if answer_type_hint not in fit_table:
        return 0.5  # Unknown hint

    per_form, otherwise = fit_table[answer_type_hint]
    return per_form.get(detected_form, otherwise)
def score_schema_fit(query: str, context: ContextPackage) -> float:
    """Score how well the query uses entities/relations/classes from the context.

    Simple heuristic: each candidate with a non-empty ``"uri"`` counts as
    matched when that URI occurs as a substring of the query. The match count
    is divided by the number of such candidates, capped at 3, so three
    matches are enough for a full score.

    Class candidates are treated like entity and relation candidates
    throughout: a context that only offers class candidates is scored, not
    short-circuited to the neutral value.

    Args:
        query: SPARQL query.
        context: Context package with ``entity_candidates``,
            ``relation_candidates`` and ``class_candidates``.

    Returns:
        Score between 0.0 and 1.0; 0.5 when no candidate carries a URI.
    """
    candidate_groups = (
        context.entity_candidates,
        context.relation_candidates,
        context.class_candidates,
    )
    # `group or ()` tolerates a missing (None) candidate list.
    uris = [
        candidate.get("uri", "")
        for group in candidate_groups
        for candidate in (group or ())
    ]
    uris = [uri for uri in uris if uri]

    if not uris:
        return 0.5  # No context to judge against

    matched = sum(1 for uri in uris if uri in query)
    return min(1.0, matched / min(len(uris), 3))
def compute_validation_score(
    parse_ok: bool,
    execute_ok: bool,
    result_count: int | None,
    answer_type_fit: float,
    schema_fit: float,
    suspicious_flags: list[str],
    weights: dict[str, float],
) -> float:
    """Combine the individual check results into one selection score.

    Formula (each number is the default used when ``weights`` lacks the key):
        score = + 5.0 if parse_ok                    ("parse_ok")
                + 5.0 if execute_ok                  ("execute_ok")
                + 2.0 * answer_type_fit              ("answer_type_fit")
                + 2.0 * schema_fit                   ("schema_fit")
                - 2.0 if "timeout" flagged           ("timeout")
                - 1.5 if "empty_result" flagged      ("empty_result")
                - 1.0 if "huge_result" flagged       ("huge_result")
                - 0.5 * len(suspicious_flags)        ("suspicious_flag")

    The per-flag penalty applies to every flag, including the three that also
    carry their own dedicated penalty.

    Args:
        parse_ok: Whether query parsed.
        execute_ok: Whether query executed.
        result_count: Number of results (accepted but not read here; its
            effect arrives through ``suspicious_flags``).
        answer_type_fit: Answer type fit score [0,1].
        schema_fit: Schema fit score [0,1].
        suspicious_flags: List of suspicious flag strings.
        weights: Scoring weights dict.

    Returns:
        Total validation score, rounded to 4 decimal places.
    """
    total = 0.0

    # Flat bonuses for passing the hard checks.
    for passed, key in ((parse_ok, "parse_ok"), (execute_ok, "execute_ok")):
        if passed:
            total += weights.get(key, 5.0)

    # Graded contributions.
    total += weights.get("answer_type_fit", 2.0) * answer_type_fit
    total += weights.get("schema_fit", 2.0) * schema_fit

    # Dedicated penalties for specific flags.
    dedicated_penalties = (
        ("timeout", -2.0),
        ("empty_result", -1.5),
        ("huge_result", -1.0),
    )
    for flag, default_penalty in dedicated_penalties:
        if flag in suspicious_flags:
            total += weights.get(flag, default_penalty)

    # Generic penalty charged once per flag.
    total += weights.get("suspicious_flag", -0.5) * len(suspicious_flags)
    return round(total, 4)
def validate_candidate(
    candidate: CandidateQuery,
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> ValidationResult:
    """Run every symbolic check on one candidate and score it.

    Checks performed:
        - Parser check (a failure ends validation immediately)
        - Endpoint execution
        - Timeout check
        - Result count check
        - Answer type sanity check
        - Schema plausibility check

    Args:
        candidate: The candidate SPARQL query.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        ValidationResult with all check results and score.
    """
    query = candidate.query
    flags: list[str] = []

    # 1. Parser check — an unparsable query is never sent to the endpoint.
    parse_ok, parse_error = parse_check(query)
    if not parse_ok:
        flags.append("parse_fail")
        failure_score = compute_validation_score(
            False, False, None, 0.0, 0.0, flags,
            runtime.selection_weights,
        )
        return ValidationResult(
            candidate_id=candidate.candidate_id,
            parse_ok=False,
            execute_ok=False,
            timeout=False,
            execution_error=parse_error,
            result_count=None,
            result_preview=[],
            answer_type_fit=0.0,
            schema_fit=0.0,
            suspicious_flags=flags,
            score=failure_score,
        )

    # 2. Endpoint execution
    execute_ok, results, result_count, exec_error, timed_out = execute_query(
        query, dataset.endpoint_url, runtime.request_timeout_sec
    )
    for raised, flag in ((timed_out, "timeout"), (not execute_ok, "execute_fail")):
        if raised:
            flags.append(flag)
    # A None count (execution failed) triggers neither size flag.
    if result_count == 0:
        flags.append("empty_result")
    elif result_count is not None and result_count > _HUGE_RESULT_THRESHOLD:
        flags.append("huge_result")

    # 3. Answer type check
    answer_type_hint = context.answer_type_hint or "select"
    at_fit = score_answer_type_fit(request.question, query, answer_type_hint)
    query_form = _detect_query_form(query)
    # Only flag if there's a clear mismatch: an undetected form is not
    # flagged, and a SELECT is tolerated where a count was expected.
    differs = query_form not in ("unknown", answer_type_hint)
    tolerated = answer_type_hint == "count" and query_form == "select"
    if differs and not tolerated:
        flags.append("form_mismatch")

    # 4. Schema fit
    s_fit = score_schema_fit(query, context)

    # 5. Compute score
    final_score = compute_validation_score(
        parse_ok, execute_ok, result_count, at_fit, s_fit,
        flags, runtime.selection_weights,
    )
    return ValidationResult(
        candidate_id=candidate.candidate_id,
        parse_ok=parse_ok,
        execute_ok=execute_ok,
        timeout=timed_out,
        execution_error=exec_error,
        result_count=result_count,
        result_preview=results,
        answer_type_fit=at_fit,
        schema_fit=s_fit,
        suspicious_flags=flags,
        score=final_score,
    )
def validate_all(
    candidates: list[CandidateQuery],
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> list[ValidationResult]:
    """Validate each candidate in turn, logging before and after each one.

    Args:
        candidates: List of candidate queries.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        List of ValidationResult objects, one per candidate, in input order.
    """
    validated: list[ValidationResult] = []
    for item in candidates:
        logger.info("Validating candidate %s", item.candidate_id)
        outcome = validate_candidate(item, request, context, dataset, runtime)
        validated.append(outcome)
        logger.info(
            "Candidate %s: score=%.2f, flags=%s",
            item.candidate_id, outcome.score, outcome.suspicious_flags,
        )
    return validated