nl-sql / scripts /run_selfcon_retry.py
liovina's picture
Deploy NL_SQL HEAD to HF Space
4b4ff9e verified
Raw
History Blame Contribute Delete
11.3 kB
"""Targeted self-consistency retry on baseline failures with Mistral codestral.
For each failing question, runs the production G pipeline N times at distinct
temperatures (0.2, 0.4, 0.6, 0.8 by default), executes each candidate, and
votes via the largest fingerprint cluster (ties → highest confidence). Output
is voting-shaped for `merge_voting_rescues.py`.
Same model (Mistral codestral) — wins beyond ~1-2 are unlikely because
voting same-model against itself plateaus, but it's a free-tier sanity probe.
Usage:
uv run python scripts/run_selfcon_retry.py \
--baseline eval/reports/2026-05-13/hybrid+multi-vote+critique-v4.json \
--out eval/reports/2026-05-13/selfcon-retry.json
uv run python scripts/run_selfcon_retry.py \
--baseline eval/reports/2026-05-22/v20-kimi-k2-thinking-merged.json \
--out eval/reports/2026-05-22/selfcon-qid1399.json --only-qids 1399
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
from nl_sql.agent.graph import PipelineConfig, build_pipeline, run_pipeline
from nl_sql.config import get_settings
from nl_sql.db.registry import get_default_registry
from nl_sql.eval.dataset import load_bird_mini_dev
from nl_sql.eval.metrics.execution_accuracy import compare_results
from nl_sql.eval.runner import _compose_question, _execute_gold
from nl_sql.eval.self_consistency import Candidate, vote
from nl_sql.execution.runner import execute_validated
from nl_sql.llm.cache import CachingEmbeddingProvider, CachingLLMProvider
from nl_sql.llm.providers.base import (
EmbedRequest,
EmbedResponse,
GenerateRequest,
GenerateResponse,
ProviderError,
)
from nl_sql.llm.providers.mistral import MistralProvider
from nl_sql.schema_index.indexer import SchemaIndex
class RotatingMistralProvider:
"""Round-robin wrapper across N MistralProvider instances (different API keys).
On a 429 / rate-limit error, advances to the next key and retries. After a
full rotation without success, applies escalating backoff (5s * extra-attempts)
and keeps trying up to 2*N attempts before surrendering.
"""
name = "mistral"
def __init__(self, providers: list[MistralProvider]) -> None:
if not providers:
raise ProviderError("RotatingMistralProvider requires >=1 provider")
self._providers = providers
self._idx = 0
self.model = providers[0].model
self.embed_model = providers[0].embed_model
def _is_rate_limit(self, err: Exception) -> bool:
msg = str(err)
return "429" in msg or "Rate limit" in msg or "rate_limited" in msg
def _advance(self) -> None:
self._idx = (self._idx + 1) % len(self._providers)
def generate(self, req: GenerateRequest) -> GenerateResponse:
n = len(self._providers)
last_err: Exception | None = None
for attempt in range(n * 2):
prov = self._providers[self._idx]
try:
return prov.generate(req)
except ProviderError as exc:
if not self._is_rate_limit(exc):
raise
last_err = exc
self._advance()
if attempt >= n - 1:
time.sleep(5.0 * (attempt - n + 2))
raise ProviderError(f"all {n} keys rate-limited: {last_err}")
def embed(self, req: EmbedRequest) -> EmbedResponse:
n = len(self._providers)
last_err: Exception | None = None
for _ in range(n):
try:
return self._providers[self._idx].embed(req)
except ProviderError as exc:
if not self._is_rate_limit(exc):
raise
last_err = exc
self._advance()
raise ProviderError(f"all {n} keys rate-limited for embed: {last_err}")
def main() -> int:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--baseline", type=Path, required=True)
p.add_argument("--bird-root", type=Path, default=Path("data/bird_mini_dev/MINIDEV"))
p.add_argument(
"--only-qids",
default="",
help="comma-separated baseline failure qids to retry exactly, preserving argument order",
)
p.add_argument("--temperatures", nargs="+", type=float, default=[0.2, 0.4, 0.6, 0.8])
p.add_argument("--gen-model", default="codestral-latest", help="Mistral model id")
p.add_argument(
"--sleep-between",
type=float,
default=0.0,
help="seconds between pipeline calls (use for mistral-large rate limits)",
)
p.add_argument(
"--api-keys",
default=None,
help="CSV of Mistral API keys for round-robin rotation. Default: settings.mistral_api_key.",
)
p.add_argument("--out", type=Path, required=True)
args = p.parse_args()
baseline = json.loads(args.baseline.read_text(encoding="utf-8"))
fails = [r for r in baseline["records"] if not r.get("match")]
try:
only_qids = [int(x) for x in args.only_qids.split(",") if x.strip()]
except ValueError:
print("[error] invalid --only-qids: expected comma-separated integers", file=sys.stderr)
return 3
if only_qids:
fails_by_qid = {int(r["question_id"]): r for r in fails}
missing_qids = [qid for qid in only_qids if qid not in fails_by_qid]
if missing_qids:
print(f"[error] qids not found in baseline failures: {missing_qids}", file=sys.stderr)
return 3
fails = [fails_by_qid[qid] for qid in only_qids]
settings = get_settings()
if args.api_keys:
keys = [k.strip() for k in args.api_keys.split(",") if k.strip()]
else:
keys = [settings.mistral_api_key]
if not keys or not keys[0]:
print("[error] no Mistral API keys provided", file=sys.stderr)
return 1
print(
f"[info] {len(fails)} failures, temps={args.temperatures}, model={args.gen_model}, keys={len(keys)}",
file=sys.stderr,
)
examples = {e.question_id: e for e in load_bird_mini_dev(args.bird_root)}
registry = get_default_registry()
gen_providers = [MistralProvider(api_key=k, gen_model=args.gen_model) for k in keys]
mistral = RotatingMistralProvider(gen_providers) if len(keys) > 1 else gen_providers[0]
sql_prov = CachingLLMProvider(mistral, cache_dir=settings.llm_cache_dir)
embed_providers = [MistralProvider(api_key=k) for k in keys]
emb_base = RotatingMistralProvider(embed_providers) if len(keys) > 1 else embed_providers[0]
emb = CachingEmbeddingProvider(emb_base, cache_dir=settings.llm_cache_dir)
idx = SchemaIndex(persist_dir="chroma_data", embedder=emb)
pipelines = [
build_pipeline(
PipelineConfig(
sql_provider=sql_prov,
explain_provider=sql_prov,
schema_index=idx,
registry=registry,
fewshot_top_k=3,
sort_schema_block=True,
cross_db_fewshot=True,
verify_retry_on_empty=False,
sql_temperature=t,
)
)
for t in args.temperatures
]
records = []
rescued = 0
regressed = 0
same = 0
for i, br in enumerate(fails, 1):
qid = br["question_id"]
ex = examples.get(qid)
if ex is None:
continue
spec = registry.get(ex.registry_db_id)
engine = spec.make_engine()
try:
t0 = time.perf_counter()
candidates = []
for pipeline, temp in zip(pipelines, args.temperatures, strict=True):
try:
r = run_pipeline(
pipeline,
question=_compose_question(ex),
db_id=ex.registry_db_id,
dialect="sqlite",
)
candidates.append(Candidate(result=r, temperature=temp))
except Exception as exc:
print(f"[{i:3d}/{len(fails)}] qid={qid} T={temp} EXC: {exc}", file=sys.stderr)
if args.sleep_between > 0:
time.sleep(args.sleep_between)
if not candidates:
continue
winner = vote(candidates)
elapsed = (time.perf_counter() - t0) * 1000.0
alt_sql = winner.result.sql or ""
try:
outcome = execute_validated(
engine,
alt_sql,
dialect="sqlite",
statement_timeout_ms=30_000,
row_cap=10_000,
)
alt_rows = list(outcome.result.rows) if outcome.result else []
except Exception:
alt_rows = []
try:
gold_rows, _ = _execute_gold(
engine, ex.sql, statement_timeout_ms=30_000, row_cap=10_000
)
except Exception:
gold_rows = []
alt_cmp = compare_results(gold_rows, alt_rows, gold_sql=ex.sql)
alt_match = bool(alt_cmp.match)
if alt_match and not br.get("match"):
rescued += 1
tag = "RESCUE"
elif br.get("match") and not alt_match:
regressed += 1
tag = "regression"
else:
same += 1
tag = "same"
records.append(
{
"question_id": qid,
"db_id": ex.db_id,
"difficulty": ex.difficulty,
"question": ex.question,
"gold_sql": ex.sql,
"baseline_pred": br["pred_sql"],
"alt_pred": alt_sql,
"alt_confidence": getattr(winner.result, "confidence", None),
"winner_temperature": winner.temperature,
"baseline_match": bool(br.get("match")),
"alt_match": alt_match,
"vote_match": alt_match,
"vote_source": "self-consistency",
"elapsed_ms": elapsed,
}
)
print(
f"[{i:3d}/{len(fails)}] qid={qid} {ex.difficulty:11s} {tag} T_win={winner.temperature:.1f} ({elapsed:.0f}ms)",
file=sys.stderr,
)
finally:
engine.dispose()
print("\n=== self-consistency retry summary ===", file=sys.stderr)
print(f" cases: {len(records)}", file=sys.stderr)
print(f" rescued: {rescued}", file=sys.stderr)
print(f" regressed: {regressed}", file=sys.stderr)
print(f" same: {same}", file=sys.stderr)
args.out.parent.mkdir(parents=True, exist_ok=True)
args.out.write_text(
json.dumps(
{
"alt_model": f"{args.gen_model}+self-consistency",
"temperatures": list(args.temperatures),
"summary": {"voted_better": rescued, "voted_worse": regressed, "voted_same": same},
"records": records,
},
indent=2,
),
encoding="utf-8",
)
return 0
if __name__ == "__main__":
raise SystemExit(main())