"""Cheap symbolic validation for the Text2SPARQL repair pipeline.

All validation is symbolic — no LLM calls. Scores candidates for selection.
"""

from __future__ import annotations

import logging
import re
from typing import Any

from .config import RuntimeConfig
from .models import (
    CandidateQuery,
    ContextPackage,
    DatasetConfig,
    QueryRequest,
    ValidationResult,
)

logger = logging.getLogger(__name__)

# Threshold for "huge result" flag
_HUGE_RESULT_THRESHOLD = 10000

# Maximum number of SELECT result rows kept as a preview.
_RESULT_PREVIEW_LIMIT = 5

# Answer-type hints that score_answer_type_fit knows how to judge; any other
# hint gets a neutral score there and must not trigger a form mismatch here.
_KNOWN_ANSWER_TYPES = frozenset({"ask", "count", "select"})

# One prologue item at the current position: a comment line, a PREFIX
# declaration (SPARQL allows "PREFIX p:<iri>" with no space before the IRI),
# or a BASE declaration.
_PROLOGUE_ITEM_RE = re.compile(
    r"\s*(?:#[^\n]*|PREFIX\s+[^\s:]*:\s*<[^>]*>|BASE\s*<[^>]*>)",
    re.IGNORECASE,
)

# A COUNT aggregate call, tolerating any whitespace before the parenthesis.
_COUNT_CALL_RE = re.compile(r"\bCOUNT\s*\(", re.IGNORECASE)

# Any SPARQL query-form keyword as a whole word.
_QUERY_KEYWORD_RE = re.compile(
    r"\b(?:SELECT|ASK|CONSTRUCT|DESCRIBE)\b", re.IGNORECASE
)

# PREFIX declarations, capturing (prefix label, namespace IRI).
_PREFIX_DECL_RE = re.compile(r"PREFIX\s+([^\s:]*):\s*<([^>]*)>", re.IGNORECASE)


def parse_check(query: str) -> tuple[bool, str | None]:
    """Check whether a SPARQL query parses correctly using rdflib.

    Falls back to a structural check when rdflib is not installed.

    Args:
        query: SPARQL query string.

    Returns:
        Tuple of (parse_ok, error_message).
    """
    try:
        from rdflib.plugins.sparql.parser import parseQuery
    except ImportError:
        # If rdflib is not installed, do a basic structural check
        logger.warning("rdflib not installed — using basic parse check")
        return _basic_parse_check(query)

    try:
        parseQuery(query)
    except Exception as exc:
        # Treat any exception raised by the parser as a parse failure.
        return False, str(exc)
    return True, None


def _basic_parse_check(query: str) -> tuple[bool, str | None]:
    """Basic structural SPARQL parse check without rdflib.

    Checks for a query-form keyword, balanced braces, and a group graph
    pattern (``{ ... }``) for every form except DESCRIBE. The WHERE keyword
    itself is optional in SPARQL, so ``SELECT * { ?s ?p ?o }`` is accepted.

    Args:
        query: SPARQL query string.

    Returns:
        Tuple of (parse_ok, error_message).
    """
    if not _QUERY_KEYWORD_RE.search(query):
        return False, "No SPARQL query keyword found (SELECT/ASK/CONSTRUCT/DESCRIBE)"

    # Check balanced braces
    open_count = query.count("{")
    close_count = query.count("}")
    if open_count != close_count:
        return False, f"Unbalanced braces: {open_count} open, {close_count} close"

    # Only DESCRIBE may legitimately omit the graph pattern (DESCRIBE <iri>).
    if open_count == 0 and _detect_query_form(query) != "describe":
        return False, "Missing WHERE clause"

    return True, None


def execute_query(
    query: str, endpoint_url: str, timeout_sec: int
) -> tuple[bool, list[dict], int | None, str | None, bool]:
    """Execute a SPARQL query against an endpoint.

    The query is sent via HTTP POST asking for JSON results. For SELECT
    queries ``result_count`` is the full number of bindings, while
    ``results`` holds only the first ``_RESULT_PREVIEW_LIMIT`` rows, each
    flattened to ``{variable: value}``.

    Args:
        query: SPARQL query string.
        endpoint_url: SPARQL endpoint URL.
        timeout_sec: Request timeout in seconds.

    Returns:
        Tuple of (execute_ok, results, result_count, error_message, timed_out).
    """
    try:
        from SPARQLWrapper import JSON, POST, SPARQLWrapper
    except ImportError:
        logger.warning("SPARQLWrapper not installed — skipping endpoint execution")
        return False, [], None, "SPARQLWrapper not installed", False

    try:
        sparql = SPARQLWrapper(endpoint_url)
        sparql.setQuery(query)
        sparql.setReturnFormat(JSON)
        sparql.setTimeout(timeout_sec)
        sparql.setMethod(POST)
        raw_results = sparql.query().convert()
        # Summarising stays inside the try so a malformed response is
        # reported as an execution error rather than raised to the caller.
        return _summarize_results(raw_results)
    except Exception as exc:
        error_str = str(exc)
        # Timeouts are recognised heuristically from the error text.
        timed_out = any(
            phrase in error_str.lower()
            for phrase in ("timeout", "timed out", "time out", "deadline")
        )
        return False, [], None, error_str[:500], timed_out


def _summarize_results(
    raw_results: Any,
) -> tuple[bool, list[dict], int | None, str | None, bool]:
    """Convert a converted endpoint response into execute_query's result tuple."""
    if not isinstance(raw_results, dict):
        # Not a SPARQL JSON results document — presumably an RDF graph from a
        # CONSTRUCT/DESCRIBE query (TODO confirm for the installed
        # SPARQLWrapper version). Report success and count what len() gives.
        if isinstance(raw_results, (str, bytes)):
            return True, [], None, None, False
        try:
            return True, [], len(raw_results), None, False
        except TypeError:
            return True, [], None, None, False

    if "boolean" in raw_results:
        # ASK query
        return True, [{"boolean": raw_results["boolean"]}], 1, None, False

    if "results" in raw_results and "bindings" in raw_results["results"]:
        bindings = raw_results["results"]["bindings"]
        # Keep only the first few rows for the preview.
        preview = [
            {var: term.get("value", "") for var, term in row.items()}
            for row in bindings[:_RESULT_PREVIEW_LIMIT]
        ]
        return True, preview, len(bindings), None, False

    return True, [], 0, None, False


def _strip_prologue(query: str) -> str:
    """Return *query* without leading comments and PREFIX/BASE declarations."""
    pos = 0
    while True:
        match = _PROLOGUE_ITEM_RE.match(query, pos)
        if match is None:
            return query[pos:].strip()
        # Every alternative consumes at least one character, so this advances.
        pos = match.end()


def _detect_query_form(query: str) -> str:
    """Detect the SPARQL query form.

    Returns one of "ask", "count" (any query using a COUNT aggregate that is
    not an ASK), "select", "construct", "describe" or "unknown".
    """
    body = _strip_prologue(query).upper()
    if body.startswith("ASK"):
        return "ask"
    if _COUNT_CALL_RE.search(body):
        return "count"
    for form in ("SELECT", "CONSTRUCT", "DESCRIBE"):
        if body.startswith(form):
            return form.lower()
    return "unknown"


def score_answer_type_fit(
    question: str, query: str, answer_type_hint: str
) -> float:
    """Score how well the query form matches the expected answer type.

    Args:
        question: Natural language question (not used by this heuristic).
        query: SPARQL query.
        answer_type_hint: Expected type ("ask", "count", "select").

    Returns:
        Score between 0.0 and 1.0; 0.5 for an unrecognised hint.
    """
    query_form = _detect_query_form(query)

    if answer_type_hint == "ask":
        return 1.0 if query_form == "ask" else 0.0

    if answer_type_hint == "count":
        if query_form == "count":
            return 1.0
        if query_form == "select":
            return 0.3  # Select could still work
        return 0.0

    if answer_type_hint == "select":
        if query_form == "select":
            return 1.0
        if query_form == "count":
            return 0.3
        return 0.2

    return 0.5  # Unknown hint


def _query_mentions_uri(
    query: str, uri: str, prefixes: list[tuple[str, str]]
) -> bool:
    """Return True if *uri* occurs in *query* verbatim or as a prefixed name.

    ``prefixes`` are the (label, namespace) pairs declared by the query's own
    PREFIX lines; prefixes predefined by an endpoint are not known here.
    """
    if uri in query:
        return True
    for label, namespace in prefixes:
        if not namespace or not uri.startswith(namespace) or len(uri) == len(namespace):
            continue
        pname = f"{label}:{uri[len(namespace):]}"
        # Require token boundaries so dbo:country does not match dbo:countryCode;
        # a trailing "." is allowed because it may be the triple terminator.
        pattern = (
            r"(?<![\w\-.:])" + re.escape(pname) + r"(?![\w\-:%])(?!\.[\w\-:%])"
        )
        if re.search(pattern, query):
            return True
    return False


def score_schema_fit(query: str, context: ContextPackage) -> float:
    """Score how well the query uses entities/relations/classes from the context.

    Heuristic: the fraction of context URIs (entity, relation and class
    candidates) that the query mentions, either verbatim or as a prefixed
    name, relative to at most 3 candidates and capped at 1.0.

    Args:
        query: SPARQL query.
        context: Context package with candidates.

    Returns:
        Score between 0.0 and 1.0; 0.5 when the context has no URIs to judge.
    """
    uris = [
        candidate.get("uri", "")
        for group in (
            context.entity_candidates,
            context.relation_candidates,
            context.class_candidates,
        )
        for candidate in group or ()
    ]
    uris = [uri for uri in uris if uri]
    if not uris:
        return 0.5  # No context to judge against

    prefixes = _PREFIX_DECL_RE.findall(query)
    matched = sum(1 for uri in uris if _query_mentions_uri(query, uri, prefixes))
    return min(1.0, matched / min(len(uris), 3))


def compute_validation_score(
    parse_ok: bool,
    execute_ok: bool,
    result_count: int | None,
    answer_type_fit: float,
    schema_fit: float,
    suspicious_flags: list[str],
    weights: dict[str, float],
) -> float:
    """Compute the validation score using the fixed scoring formula.

    Formula (default weights; each can be overridden via ``weights``):
        score =
            + 5.0 if parse_ok
            + 5.0 if execute_ok
            + 2.0 * answer_type_fit
            + 2.0 * schema_fit
            - 2.0 if timeout
            - 1.5 if empty_result
            - 1.0 if huge_result
            - 0.5 * len(suspicious_flags)

    Args:
        parse_ok: Whether query parsed.
        execute_ok: Whether query executed.
        result_count: Number of results. Not read here; result-size signals
            arrive through the "empty_result"/"huge_result" flags.
        answer_type_fit: Answer type fit score [0,1].
        schema_fit: Schema fit score [0,1].
        suspicious_flags: List of suspicious flag strings.
        weights: Scoring weights dict.

    Returns:
        Total validation score, rounded to 4 decimal places.
    """
    score = 0.0

    if parse_ok:
        score += weights.get("parse_ok", 5.0)
    if execute_ok:
        score += weights.get("execute_ok", 5.0)

    score += weights.get("answer_type_fit", 2.0) * answer_type_fit
    score += weights.get("schema_fit", 2.0) * schema_fit

    if "timeout" in suspicious_flags:
        score += weights.get("timeout", -2.0)
    if "empty_result" in suspicious_flags:
        score += weights.get("empty_result", -1.5)
    if "huge_result" in suspicious_flags:
        score += weights.get("huge_result", -1.0)

    # Per the formula, the per-flag penalty applies to every flag, including
    # the three penalised individually above.
    score += weights.get("suspicious_flag", -0.5) * len(suspicious_flags)

    return round(score, 4)


def _is_form_mismatch(answer_type_hint: str, query_form: str) -> bool:
    """Return True if the query form clearly contradicts the answer-type hint."""
    if answer_type_hint not in _KNOWN_ANSWER_TYPES or query_form == "unknown":
        return False
    if answer_type_hint == query_form:
        return False
    # A plain SELECT is tolerated for counting questions, mirroring the
    # partial credit score_answer_type_fit gives that combination.
    return not (answer_type_hint == "count" and query_form == "select")


def validate_candidate(
    candidate: CandidateQuery,
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> ValidationResult:
    """Validate a single candidate query.

    Runs all symbolic checks:
    - Parser check (a parse failure short-circuits the remaining checks)
    - Endpoint execution
    - Timeout check
    - Result count check
    - Answer type sanity check
    - Schema plausibility check

    Args:
        candidate: The candidate SPARQL query.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        ValidationResult with all check results and score.
    """
    flags: list[str] = []
    query = candidate.query

    # 1. Parser check
    parse_ok, parse_error = parse_check(query)
    if not parse_ok:
        flags.append("parse_fail")
        return ValidationResult(
            candidate_id=candidate.candidate_id,
            parse_ok=False,
            execute_ok=False,
            timeout=False,
            execution_error=parse_error,
            result_count=None,
            result_preview=[],
            answer_type_fit=0.0,
            schema_fit=0.0,
            suspicious_flags=flags,
            score=compute_validation_score(
                False, False, None, 0.0, 0.0, flags, runtime.selection_weights,
            ),
        )

    # 2. Endpoint execution
    execute_ok, results, result_count, exec_error, timed_out = execute_query(
        query, dataset.endpoint_url, runtime.request_timeout_sec
    )

    if timed_out:
        flags.append("timeout")
    if not execute_ok:
        flags.append("execute_fail")

    if result_count is not None:
        if result_count == 0:
            flags.append("empty_result")
        elif result_count > _HUGE_RESULT_THRESHOLD:
            flags.append("huge_result")

    # 3. Answer type check
    answer_type_hint = context.answer_type_hint or "select"
    at_fit = score_answer_type_fit(request.question, query, answer_type_hint)
    if _is_form_mismatch(answer_type_hint, _detect_query_form(query)):
        flags.append("form_mismatch")

    # 4. Schema fit
    s_fit = score_schema_fit(query, context)

    # 5. Compute score
    score = compute_validation_score(
        parse_ok, execute_ok, result_count, at_fit, s_fit, flags,
        runtime.selection_weights,
    )

    return ValidationResult(
        candidate_id=candidate.candidate_id,
        parse_ok=parse_ok,
        execute_ok=execute_ok,
        timeout=timed_out,
        execution_error=exec_error,
        result_count=result_count,
        result_preview=results,
        answer_type_fit=at_fit,
        schema_fit=s_fit,
        suspicious_flags=flags,
        score=score,
    )


def validate_all(
    candidates: list[CandidateQuery],
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> list[ValidationResult]:
    """Validate all candidate queries, sequentially and in input order.

    Args:
        candidates: List of candidate queries.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        List of ValidationResult objects, one per candidate.
    """
    results = []
    for candidate in candidates:
        logger.info("Validating candidate %s", candidate.candidate_id)
        result = validate_candidate(candidate, request, context, dataset, runtime)
        logger.info(
            "Candidate %s: score=%.2f, flags=%s",
            candidate.candidate_id,
            result.score,
            result.suspicious_flags,
        )
        results.append(result)
    return results