"""Targeted self-consistency retry on baseline failures with Mistral codestral. For each failing question, runs the production G pipeline N times at distinct temperatures (0.2, 0.4, 0.6, 0.8 by default), executes each candidate, and votes via the largest fingerprint cluster (ties → highest confidence). Output is voting-shaped for `merge_voting_rescues.py`. Same model (Mistral codestral) — wins beyond ~1-2 are unlikely because voting same-model against itself plateaus, but it's a free-tier sanity probe. Usage: uv run python scripts/run_selfcon_retry.py \ --baseline eval/reports/2026-05-13/hybrid+multi-vote+critique-v4.json \ --out eval/reports/2026-05-13/selfcon-retry.json uv run python scripts/run_selfcon_retry.py \ --baseline eval/reports/2026-05-22/v20-kimi-k2-thinking-merged.json \ --out eval/reports/2026-05-22/selfcon-qid1399.json --only-qids 1399 """ from __future__ import annotations import argparse import json import sys import time from pathlib import Path from nl_sql.agent.graph import PipelineConfig, build_pipeline, run_pipeline from nl_sql.config import get_settings from nl_sql.db.registry import get_default_registry from nl_sql.eval.dataset import load_bird_mini_dev from nl_sql.eval.metrics.execution_accuracy import compare_results from nl_sql.eval.runner import _compose_question, _execute_gold from nl_sql.eval.self_consistency import Candidate, vote from nl_sql.execution.runner import execute_validated from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider from nl_sql.llm.providers.base import ( EmbedRequest, EmbedResponse, GenerateRequest, GenerateResponse, ProviderError, ) from nl_sql.llm.providers.mistral import MistralProvider from nl_sql.schema_index.indexer import SchemaIndex class RotatingMistralProvider: """Round-robin wrapper across N MistralProvider instances (different API keys). On a 429 / rate-limit error, advances to the next key and retries. After a full rotation without success, applies escalating backoff (5s * extra-attempts) and keeps trying up to 2*N attempts before surrendering. """ name = "mistral" def __init__(self, providers: list[MistralProvider]) -> None: if not providers: raise ProviderError("RotatingMistralProvider requires >=1 provider") self._providers = providers self._idx = 0 self.model = providers[0].model self.embed_model = providers[0].embed_model def _is_rate_limit(self, err: Exception) -> bool: msg = str(err) return "429" in msg or "Rate limit" in msg or "rate_limited" in msg def _advance(self) -> None: self._idx = (self._idx + 1) % len(self._providers) def generate(self, req: GenerateRequest) -> GenerateResponse: n = len(self._providers) last_err: Exception | None = None for attempt in range(n * 2): prov = self._providers[self._idx] try: return prov.generate(req) except ProviderError as exc: if not self._is_rate_limit(exc): raise last_err = exc self._advance() if attempt >= n - 1: time.sleep(5.0 * (attempt - n + 2)) raise ProviderError(f"all {n} keys rate-limited: {last_err}") def embed(self, req: EmbedRequest) -> EmbedResponse: n = len(self._providers) last_err: Exception | None = None for _ in range(n): try: return self._providers[self._idx].embed(req) except ProviderError as exc: if not self._is_rate_limit(exc): raise last_err = exc self._advance() raise ProviderError(f"all {n} keys rate-limited for embed: {last_err}") def main() -> int: p = argparse.ArgumentParser(description=__doc__) p.add_argument("--baseline", type=Path, required=True) p.add_argument("--bird-root", type=Path, default=Path("data/bird_mini_dev/MINIDEV")) p.add_argument( "--only-qids", default="", help="comma-separated baseline failure qids to retry exactly, preserving argument order", ) p.add_argument("--temperatures", nargs="+", type=float, default=[0.2, 0.4, 0.6, 0.8]) p.add_argument("--gen-model", default="codestral-latest", help="Mistral model id") p.add_argument( "--sleep-between", type=float, default=0.0, help="seconds between pipeline calls (use for mistral-large rate limits)", ) p.add_argument( "--api-keys", default=None, help="CSV of Mistral API keys for round-robin rotation. Default: settings.mistral_api_key.", ) p.add_argument("--out", type=Path, required=True) args = p.parse_args() baseline = json.loads(args.baseline.read_text(encoding="utf-8")) fails = [r for r in baseline["records"] if not r.get("match")] try: only_qids = [int(x) for x in args.only_qids.split(",") if x.strip()] except ValueError: print("[error] invalid --only-qids: expected comma-separated integers", file=sys.stderr) return 3 if only_qids: fails_by_qid = {int(r["question_id"]): r for r in fails} missing_qids = [qid for qid in only_qids if qid not in fails_by_qid] if missing_qids: print(f"[error] qids not found in baseline failures: {missing_qids}", file=sys.stderr) return 3 fails = [fails_by_qid[qid] for qid in only_qids] settings = get_settings() if args.api_keys: keys = [k.strip() for k in args.api_keys.split(",") if k.strip()] else: keys = [settings.mistral_api_key] if not keys or not keys[0]: print("[error] no Mistral API keys provided", file=sys.stderr) return 1 print( f"[info] {len(fails)} failures, temps={args.temperatures}, model={args.gen_model}, keys={len(keys)}", file=sys.stderr, ) examples = {e.question_id: e for e in load_bird_mini_dev(args.bird_root)} registry = get_default_registry() gen_providers = [MistralProvider(api_key=k, gen_model=args.gen_model) for k in keys] mistral = RotatingMistralProvider(gen_providers) if len(keys) > 1 else gen_providers[0] sql_prov = CachingLLMProvider(mistral, cache_dir=settings.llm_cache_dir) embed_providers = [MistralProvider(api_key=k) for k in keys] emb_base = RotatingMistralProvider(embed_providers) if len(keys) > 1 else embed_providers[0] emb = CachingEmbeddingProvider(emb_base, cache_dir=settings.llm_cache_dir) idx = SchemaIndex(persist_dir="chroma_data", embedder=emb) pipelines = [ build_pipeline( PipelineConfig( sql_provider=sql_prov, explain_provider=sql_prov, schema_index=idx, registry=registry, fewshot_top_k=3, sort_schema_block=True, cross_db_fewshot=True, verify_retry_on_empty=False, sql_temperature=t, ) ) for t in args.temperatures ] records = [] rescued = 0 regressed = 0 same = 0 for i, br in enumerate(fails, 1): qid = br["question_id"] ex = examples.get(qid) if ex is None: continue spec = registry.get(ex.registry_db_id) engine = spec.make_engine() try: t0 = time.perf_counter() candidates = [] for pipeline, temp in zip(pipelines, args.temperatures, strict=True): try: r = run_pipeline( pipeline, question=_compose_question(ex), db_id=ex.registry_db_id, dialect="sqlite", ) candidates.append(Candidate(result=r, temperature=temp)) except Exception as exc: print(f"[{i:3d}/{len(fails)}] qid={qid} T={temp} EXC: {exc}", file=sys.stderr) if args.sleep_between > 0: time.sleep(args.sleep_between) if not candidates: continue winner = vote(candidates) elapsed = (time.perf_counter() - t0) * 1000.0 alt_sql = winner.result.sql or "" try: outcome = execute_validated( engine, alt_sql, dialect="sqlite", statement_timeout_ms=30_000, row_cap=10_000, ) alt_rows = list(outcome.result.rows) if outcome.result else [] except Exception: alt_rows = [] try: gold_rows, _ = _execute_gold( engine, ex.sql, statement_timeout_ms=30_000, row_cap=10_000 ) except Exception: gold_rows = [] alt_cmp = compare_results(gold_rows, alt_rows, gold_sql=ex.sql) alt_match = bool(alt_cmp.match) if alt_match and not br.get("match"): rescued += 1 tag = "RESCUE" elif br.get("match") and not alt_match: regressed += 1 tag = "regression" else: same += 1 tag = "same" records.append( { "question_id": qid, "db_id": ex.db_id, "difficulty": ex.difficulty, "question": ex.question, "gold_sql": ex.sql, "baseline_pred": br["pred_sql"], "alt_pred": alt_sql, "alt_confidence": getattr(winner.result, "confidence", None), "winner_temperature": winner.temperature, "baseline_match": bool(br.get("match")), "alt_match": alt_match, "vote_match": alt_match, "vote_source": "self-consistency", "elapsed_ms": elapsed, } ) print( f"[{i:3d}/{len(fails)}] qid={qid} {ex.difficulty:11s} {tag} T_win={winner.temperature:.1f} ({elapsed:.0f}ms)", file=sys.stderr, ) finally: engine.dispose() print("\n=== self-consistency retry summary ===", file=sys.stderr) print(f" cases: {len(records)}", file=sys.stderr) print(f" rescued: {rescued}", file=sys.stderr) print(f" regressed: {regressed}", file=sys.stderr) print(f" same: {same}", file=sys.stderr) args.out.parent.mkdir(parents=True, exist_ok=True) args.out.write_text( json.dumps( { "alt_model": f"{args.gen_model}+self-consistency", "temperatures": list(args.temperatures), "summary": {"voted_better": rescued, "voted_worse": regressed, "voted_same": same}, "records": records, }, indent=2, ), encoding="utf-8", ) return 0 if __name__ == "__main__": raise SystemExit(main())