#!/usr/bin/env python3
"""Minimal RAG-style retrieval + simple faithfulness check (Horizon 1 short-term C).

Chunks a FAQ markdown corpus by `##` sections, embeds with TinyModelRuntime,
retrieves top matches for a query, and reports **keyword overlap** in the top hit
as a cheap faithfulness proxy (not neural entailment).

Optional **--show-train-routing** prints Phase 2 **`routing`** notes from the
checkpoint's **eval_report.json** (same helper as **embeddings_smoke_test** /
**horizon1_route_then_retrieve**).
"""
from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path
from typing import TYPE_CHECKING

_scripts = Path(__file__).resolve().parent
if str(_scripts) not in sys.path:
    sys.path.insert(0, str(_scripts))

from eval_report_routing import maybe_print_routing_section

if TYPE_CHECKING:
    # Imported lazily at runtime in main(); this guard keeps the annotation
    # on hybrid_retrieve() resolvable for type checkers.
    from tinymodel_runtime import TinyModelRuntime

_PROG = "rag_faq_smoke"
_STOP = frozenset(
    "a an the to of and or for in on at is are was be as it with from by not".split()
)
_REPO = Path(__file__).resolve().parent.parent
# When --model is omitted, the first existing dir wins; else published Hub weights.
_DEFAULT_MODEL_DIRS = (
    "artifacts/horizon1/three-tasks/ag_news",
    "artifacts/phase1/runs/smoke/ag_news/scratch",
    ".tmp/TinyModel-local",
    ".tmp/horizon1-verify-a",
)
_DEFAULT_HUB = "HyperlinksSpace/TinyModel1"


def _pick_model(explicit: str | None) -> str:
    """Resolve a local checkpoint dir, or a Hugging Face model id (namespace/name)."""
    if explicit is None:
        for rel in _DEFAULT_MODEL_DIRS:
            d = _REPO / rel
            if (d / "config.json").is_file():
                return str(d.resolve())
        return _DEFAULT_HUB
    p = Path(explicit)
    for d in (p.resolve(), (_REPO / explicit).resolve()):
        if d.is_dir() and (d / "config.json").is_file():
            return str(d)
    if p.exists() or (_REPO / explicit).exists():
        print(
            f"Not a model directory (expected config.json): {explicit!r}\n"
            "Train first, e.g.:\n"
            "  python scripts/train_tinymodel1_agnews.py --output-dir .tmp/rag-encoder "
            "--max-train-samples 200 --max-eval-samples 100 --epochs 1 --batch-size 8 --seed 42",
            file=sys.stderr,
        )
        raise SystemExit(1)
    return explicit  # Hub id, e.g. HyperlinksSpace/TinyModel1


def build_parser() -> argparse.ArgumentParser:
    epilog = (
        "Examples:\n"
        "  python scripts/rag_faq_smoke.py\n"
        "  python scripts/rag_faq_smoke.py --query \"How do I get a refund?\" --top-k 3\n"
        "  python scripts/rag_faq_smoke.py --model artifacts/phase1/runs/smoke/ag_news/scratch "
        "--show-train-routing\n"
        "If --model is omitted, the first default checkpoint dir with config.json is used, "
        f"else {_DEFAULT_HUB!r} (see --model above)."
    )
    p = argparse.ArgumentParser(
        prog=_PROG,
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    p.add_argument(
        "--model",
        type=str,
        default=None,
        help=(
            "Trained checkpoint directory or Hugging Face model id. "
            f"If omitted, uses the first of {_DEFAULT_MODEL_DIRS} that contains config.json, "
            f"else {_DEFAULT_HUB!r}."
        ),
    )
    p.add_argument(
        "--corpus",
        type=str,
        default="texts/rag_faq_corpus.md",
        help="Markdown file with ##-delimited chunks.",
    )
    p.add_argument("--top-k", type=int, default=2)
    p.add_argument(
        "--semantic-only",
        action="store_true",
        help="Use only TinyModelRuntime.retrieve (stricter; tiny encoders may fail on short FAQ chunks).",
    )
    p.add_argument(
        "--query",
        type=str,
        default=None,
        help=(
            "If set, run a single retrieval for this query and print top-k chunks with scores "
            "(citation-style index into the chunk list). Skips the built-in smoke assertions."
        ),
    )
    p.add_argument(
        "--show-train-routing",
        action="store_true",
        help="Print eval_report.json top-level routing (Phase 2 notes) before retrieval output.",
    )
    return p


def parse_args() -> argparse.Namespace:
    return build_parser().parse_args()


def load_chunks(corpus: Path) -> list[str]:
    text = corpus.read_text(encoding="utf-8")
    # `re.split` with a capture group yields [preamble, title1, body1, title2, body2, ...]
    parts = re.split(r"(?m)^##\s+(.+)$", text)
    chunks: list[str] = []
    for idx in range(1, len(parts), 2):
        if idx + 1 >= len(parts):
            break
        title = parts[idx].strip()
        body = parts[idx + 1].strip()
        if body:
            chunks.append(f"{title}\n{body}")
    return chunks if chunks else [text.strip()]
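
# Worked example for load_chunks (illustrative only; not executed at import time).
# Given a corpus file containing:
#
#   Intro text before the first heading.
#
#   ## Refunds
#   Request a refund within 30 days of delivery.
#
#   ## Account security
#   Reset your password immediately after any unauthorized login.
#
# load_chunks() returns two "title\nbody" strings:
#   ["Refunds\nRequest a refund within 30 days of delivery.",
#    "Account security\nReset your password immediately after any unauthorized login."]
# Preamble text before the first `##` heading is dropped; if the file has no
# `##` headings at all, the whole stripped file becomes the single chunk.
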

def tokenize(s: str) -> set[str]:
    return {w.lower() for w in re.findall(r"[A-Za-z0-9']+", s) if w.lower() not in _STOP}


def overlap_faithfulness(query: str, chunk: str) -> float:
    q, c = tokenize(query), tokenize(chunk)
    if not q:
        return 0.0
    return len(q & c) / max(len(q), 1)


def lex_substring_score(query: str, chunk: str) -> float:
    """Cheap overlap: fraction of 3+ char alphanumeric query tokens that appear as substrings."""
    cl = chunk.lower()
    hit = tot = 0
    for w in re.findall(r"[a-z0-9]+", query.lower()):
        if len(w) < 3:
            continue
        tot += 1
        if w in cl:
            hit += 1
    return hit / max(tot, 1)


def hybrid_retrieve(
    rt: TinyModelRuntime,
    query: str,
    chunks: list[str],
    *,
    top_k: int,
    embed_weight: float = 0.45,
) -> list[tuple[float, int, str]]:
    """Combine cosine (encoder) + lexical overlap so tiny scratch encoders still rank sensible FAQ chunks."""
    if not chunks:
        return []
    texts = [query, *chunks]
    embs = rt.embed(texts, normalize=True)
    qe = embs[0:1]
    ce = embs[1:]
    cos = (qe @ ce.T).squeeze(0)
    ranked: list[tuple[float, int]] = []
    for i, ch in enumerate(chunks):
        lex = lex_substring_score(query, ch)
        s = embed_weight * float(cos[i]) + (1.0 - embed_weight) * lex
        ranked.append((s, i))
    ranked.sort(key=lambda x: -x[0])
    out: list[tuple[float, int, str]] = []
    for s, i in ranked[:top_k]:
        out.append((s, i, chunks[i]))
    return out
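
# Worked example of the hybrid score (illustrative numbers, not measured output):
# with the default embed_weight=0.45, a chunk whose cosine against the query is
# only 0.10 but whose lexical substring overlap is 0.75 scores
#   0.45 * 0.10 + 0.55 * 0.75 = 0.4575,
# so strong keyword matches can outrank a noisy tiny-encoder cosine. Setting
# embed_weight=1.0 reduces the score to pure cosine ranking, similar in spirit
# to the --semantic-only path, which goes through rt.retrieve instead.
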

def main() -> None:
    args = parse_args()
    model_id = _pick_model(args.model)
    if args.model is None:
        print(f"rag_faq_smoke: using --model {model_id!r} (set explicitly to override).", file=sys.stderr)
    corpus = Path(args.corpus)
    if not corpus.is_file():
        print(f"Corpus not found: {corpus}", file=sys.stderr)
        raise SystemExit(1)
    chunks = load_chunks(corpus)

    maybe_print_routing_section(
        model_id,
        enabled=args.show_train_routing,
        prog=_PROG,
    )

    from tinymodel_runtime import TinyModelRuntime

    rt = TinyModelRuntime(model_id, device="cpu", max_length=128)

    if args.query:
        q = args.query.strip()
        print("=== RAG FAQ (single query) ===\n")
        print(f"model={model_id!r}\ncorpus={corpus}\nquery={q!r}\n")
        if args.semantic_only:
            hits = rt.retrieve(q, chunks, top_k=args.top_k)
            for rank, h in enumerate(hits, 1):
                prev = h.text[:240].replace("\n", " ")
                print(f"  #{rank} idx={h.index} score={h.score:.4f} {prev!r}...")
        else:
            hr = hybrid_retrieve(rt, q, chunks, top_k=args.top_k)
            for rank, (score, idx, text) in enumerate(hr, 1):
                prev = text[:240].replace("\n", " ")
                print(f"  #{rank} idx={idx} hybrid_score={score:.4f} {prev!r}...")
        return

    print("=== RAG FAQ smoke (retrieval) ===\n")
    # (query, substring that must appear in the top-1 chunk for a pass; a citation-style check)
    samples: list[tuple[str, str]] = [
        ("How do I get a refund for my order?", "refund"),
        ("I see an unauthorized login on my account", "password"),
        ("My package tracking says exception, what do I do?", "exception"),
    ]
    all_ok = True
    for q, must in samples:
        if args.semantic_only:
            hits = rt.retrieve(q, chunks, top_k=args.top_k)
            top_text = hits[0].text
            top_score = hits[0].score
        else:
            hr = hybrid_retrieve(rt, q, chunks, top_k=args.top_k)
            top_score, _idx, top_text = hr[0]
        f = overlap_faithfulness(q, top_text)
        cited = must.lower() in top_text.lower()
        ok = cited or f >= 0.12
        if not ok:
            all_ok = False
        status = "ok" if ok else "fail"
        print(f"Q: {q}")
        print(
            f"  top hybrid/semantic score={top_score:.4f} keyword_overlap={f:.2f} "
            f"contains({must!r})={cited} [{status}]"
        )
        safe = top_text[:200].replace("\n", " ").encode("ascii", "replace").decode("ascii")
        print(f"  chunk preview: {safe}...")
        print()
    if all_ok:
        print(
            "RAG FAQ smoke: passed (default: hybrid lexical + encoder; "
            "use --semantic-only to stress pure embedding retrieval)."
        )
    else:
        print(
            "RAG smoke failed: re-train the encoder, use a larger/HF model, or add training pairs.",
            file=sys.stderr,
        )
        raise SystemExit(1)


if __name__ == "__main__":
    main()
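
# TinyModelRuntime surface this script relies on, inferred from the call sites
# above (the real definitions live in scripts/tinymodel_runtime.py, which the
# sys.path insert at the top makes importable):
#   rt.embed(texts: list[str], normalize=True) -> numpy- or torch-style array of
#     shape (len(texts), dim) supporting slicing, `@`, .T, and .squeeze(0),
#     with rows unit-normalized so qe @ ce.T gives cosine similarities.
#   rt.retrieve(query: str, chunks: list[str], top_k: int) -> ranked hits, each
#     exposing .index (position in `chunks`), .score, and .text.
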