#!/usr/bin/env python3
"""Minimal RAG-style retrieval + simple faithfulness check (Horizon 1 short-term C).
Chunks a FAQ markdown corpus by `##` sections, embeds with TinyModelRuntime, retrieves top
matches for a query, and reports **keyword overlap** in the top hit as a cheap faithfulness
proxy (not neural entailment). Optional **--show-train-routing** prints Phase 2 **`routing`**
notes from the checkpoint's **eval_report.json** (same helper as **embeddings_smoke_test** /
**horizon1_route_then_retrieve**)."""
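# Flow: load_chunks() splits the corpus on "## " headings; hybrid_retrieve()
# blends encoder cosine scores with lexical substring overlap; the smoke checks
# pass when the expected keyword is cited or keyword overlap clears a floor.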
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from tinymodel_runtime import TinyModelRuntime  # runtime import happens lazily in main()
_scripts = Path(__file__).resolve().parent
if str(_scripts) not in sys.path:
sys.path.insert(0, str(_scripts))
from eval_report_routing import maybe_print_routing_section
_PROG = "rag_faq_smoke"
_STOP = frozenset(
"a an the to of and or for in on at is are was be as it with from by not"
.split()
)
_REPO = Path(__file__).resolve().parent.parent
# When --model is omitted, first existing dir wins; else published Hub weights.
_DEFAULT_MODEL_DIRS = (
"artifacts/horizon1/three-tasks/ag_news",
"artifacts/phase1/runs/smoke/ag_news/scratch",
".tmp/TinyModel-local",
".tmp/horizon1-verify-a",
)
_DEFAULT_HUB = "HyperlinksSpace/TinyModel1"
def _pick_model(explicit: str | None) -> str:
"""Resolve local checkpoint dir, or a Hugging Face model id (namespace/name)."""
if explicit is None:
for rel in _DEFAULT_MODEL_DIRS:
d = _REPO / rel
if (d / "config.json").is_file():
return str(d.resolve())
return _DEFAULT_HUB
p = Path(explicit)
for d in (p.resolve(), (_REPO / explicit).resolve()):
if d.is_dir() and (d / "config.json").is_file():
return str(d)
if p.exists() or (_REPO / explicit).exists():
print(
f"Not a model directory (expected config.json): {explicit!r}\n"
"Train first, e.g.:\n"
" python scripts/train_tinymodel1_agnews.py --output-dir .tmp/rag-encoder "
"--max-train-samples 200 --max-eval-samples 100 --epochs 1 --batch-size 8 --seed 42",
file=sys.stderr,
)
raise SystemExit(1)
return explicit # Hub id, e.g. HyperlinksSpace/TinyModel1
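# Illustrative: _pick_model(None) returns the first _DEFAULT_MODEL_DIRS entry
# holding a config.json (resolved to an absolute path), else _DEFAULT_HUB;
# an explicit Hub id with no matching local path is passed through unchanged.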
def build_parser() -> argparse.ArgumentParser:
epilog = (
"Examples:\n"
" python scripts/rag_faq_smoke.py\n"
" python scripts/rag_faq_smoke.py --query \"How do I get a refund?\" --top-k 3\n"
" python scripts/rag_faq_smoke.py --model artifacts/phase1/runs/smoke/ag_news/scratch "
"--show-train-routing\n"
"If --model is omitted, the first default checkpoint dir with config.json is used, "
f"else {_DEFAULT_HUB!r} (see --model above)."
)
p = argparse.ArgumentParser(
prog=_PROG,
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=epilog,
)
p.add_argument(
"--model",
type=str,
default=None,
help=(
"Trained checkpoint directory or Hugging Face model id. "
f"If omitted, uses the first of {_DEFAULT_MODEL_DIRS} that contains config.json, "
f"else {_DEFAULT_HUB!r}."
),
)
p.add_argument(
"--corpus",
type=str,
default="texts/rag_faq_corpus.md",
help="Markdown file with ##-delimited chunks.",
)
p.add_argument("--top-k", type=int, default=2)
p.add_argument(
"--semantic-only",
action="store_true",
help="Use only TinyModelRuntime.retrieve (stricter; tiny encoders may fail on short FAQ chunks).",
)
p.add_argument(
"--query",
type=str,
default=None,
help=(
"If set, run a single retrieval for this query and print top-k chunks with scores "
"(citation-style index into the chunk list). Skips the built-in smoke assertions."
),
)
p.add_argument(
"--show-train-routing",
action="store_true",
help="Print eval_report.json top-level routing (Phase 2 notes) before retrieval output.",
)
return p
def parse_args() -> argparse.Namespace:
return build_parser().parse_args()
def load_chunks(corpus: Path) -> list[str]:
text = corpus.read_text(encoding="utf-8")
# `re.split` with a capture: [preamble, title1, body1, title2, body2, ...]
parts = re.split(r"(?m)^##\s+(.+)$", text)
chunks: list[str] = []
for idx in range(1, len(parts), 2):
if idx + 1 >= len(parts):
break
title = parts[idx].strip()
body = parts[idx + 1].strip()
if body:
chunks.append(f"{title}\n{body}")
return chunks if chunks else [text.strip()]
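# Illustrative: for a corpus "intro\n## Refunds\nBody A\n## Shipping\nBody B",
# re.split yields ["intro\n", "Refunds", "\nBody A\n", "Shipping", "\nBody B"],
# so load_chunks returns ["Refunds\nBody A", "Shipping\nBody B"].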
def tokenize(s: str) -> set[str]:
return {w.lower() for w in re.findall(r"[A-Za-z0-9']+", s) if w.lower() not in _STOP}
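# Illustrative: tokenize("How do I get a refund?") ->
# {"how", "do", "i", "get", "refund"} (the stop word "a" is dropped).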
def overlap_faithfulness(query: str, chunk: str) -> float:
    """Fraction of the query's non-stopword tokens that appear exactly in the chunk."""
    q, c = tokenize(query), tokenize(chunk)
    if not q:
        return 0.0
    return len(q & c) / len(q)
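# Illustrative: for "How do I get a refund?" the query tokens are
# {"how", "do", "i", "get", "refund"}; against "To get a refund, contact
# support." the overlap is {"get", "refund"}, scoring 2/5 = 0.4. Exact-token
# matching misses inflections ("refunds"); lex_substring_score covers those.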
def lex_substring_score(query: str, chunk: str) -> float:
"""Cheap overlap: fraction of 3+ char alphanumeric query tokens that appear as substrings."""
cl = chunk.lower()
hit = tot = 0
for w in re.findall(r"[a-z0-9]+", query.lower()):
if len(w) < 3:
continue
tot += 1
if w in cl:
hit += 1
return hit / max(tot, 1)
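# Illustrative: lex_substring_score("refund my order", "Refunds: cancel an order")
# -> "my" is skipped (<3 chars); "refund" and "order" both appear as substrings
# of the lowered chunk, so the score is 2/2 = 1.0.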
def hybrid_retrieve(
rt: TinyModelRuntime,
query: str,
chunks: list[str],
*,
top_k: int,
embed_weight: float = 0.45,
) -> list[tuple[float, int, str]]:
"""Combine cosine (encoder) + lexical overlap so tiny scratch encoders still rank sensible FAQ chunks."""
if not chunks:
return []
texts = [query, *chunks]
embs = rt.embed(texts, normalize=True)
qe = embs[0:1]
ce = embs[1:]
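    # With normalize=True the embeddings are unit vectors, so qe @ ce.T yields cosine similarities.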
cos = (qe @ ce.T).squeeze(0)
ranked: list[tuple[float, int]] = []
for i, ch in enumerate(chunks):
lex = lex_substring_score(query, ch)
s = embed_weight * float(cos[i]) + (1.0 - embed_weight) * lex
ranked.append((s, i))
ranked.sort(key=lambda x: -x[0])
out: list[tuple[float, int, str]] = []
for s, i in ranked[:top_k]:
out.append((s, i, chunks[i]))
return out
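# Illustrative: with embed_weight=0.45, a chunk with cosine 0.30 and lexical
# overlap 0.80 scores 0.45 * 0.30 + 0.55 * 0.80 = 0.575, so strong keyword
# matches can outrank noisy embeddings from a tiny scratch encoder.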
def main() -> None:
args = parse_args()
model_id = _pick_model(args.model)
if args.model is None:
print(f"rag_faq_smoke: using --model {model_id!r} (set explicitly to override).", file=sys.stderr)
corpus = Path(args.corpus)
if not corpus.is_file():
print(f"Corpus not found: {corpus}", file=sys.stderr)
raise SystemExit(1)
chunks = load_chunks(corpus)
maybe_print_routing_section(
model_id, enabled=args.show_train_routing, prog=_PROG,
)
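    # Deferred import: argument/corpus errors above surface before the heavier runtime import.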
from tinymodel_runtime import TinyModelRuntime
rt = TinyModelRuntime(model_id, device="cpu", max_length=128)
if args.query:
q = args.query.strip()
print("=== RAG FAQ (single query) ===\n")
print(f"model={model_id!r}\ncorpus={corpus}\nquery={q!r}\n")
if args.semantic_only:
hits = rt.retrieve(q, chunks, top_k=args.top_k)
for rank, h in enumerate(hits, 1):
prev = h.text[:240].replace("\n", " ")
print(f" #{rank} idx={h.index} score={h.score:.4f} {prev!r}...")
else:
hr = hybrid_retrieve(rt, q, chunks, top_k=args.top_k)
for rank, (score, idx, text) in enumerate(hr, 1):
prev = text[:240].replace("\n", " ")
print(f" #{rank} idx={idx} hybrid_score={score:.4f} {prev!r}...")
return
print("=== RAG FAQ smoke (retrieval) ===\n")
# (query, substring that must appear in top-1 chunk for a pass — citation-style check)
samples: list[tuple[str, str]] = [
("How do I get a refund for my order?", "refund"),
("I see an unauthorized login on my account", "password"),
        ("My package tracking says exception, what do I do?", "exception"),
]
all_ok = True
for q, must in samples:
if args.semantic_only:
hits = rt.retrieve(q, chunks, top_k=args.top_k)
top_text = hits[0].text
top_score = hits[0].score
else:
hr = hybrid_retrieve(rt, q, chunks, top_k=args.top_k)
top_score, _idx, top_text = hr[0]
f = overlap_faithfulness(q, top_text)
cited = must.lower() in top_text.lower()
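        # Pass criterion: the expected keyword is cited verbatim in the top chunk,
        # or at least ~12% of the query's content tokens appear in it.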
ok = cited or f >= 0.12
if not ok:
all_ok = False
status = "ok" if ok else "fail"
print(f"Q: {q}")
print(
f" top hybrid/semantic score={top_score:.4f} keyword_overlap={f:.2f} "
f"contains({must!r})={cited} [{status}]"
)
        safe = top_text[:200].replace("\n", " ").encode("ascii", "replace").decode("ascii")
print(f" chunk preview: {safe}...")
print()
if all_ok:
print(
"RAG FAQ smoke: passed (default: hybrid lexical + encoder; use --semantic-only to stress pure embedding retrieval).",
)
else:
print(
"RAG smoke failed: re-train the encoder, use a larger/HF model, or add training pairs.",
file=sys.stderr,
)
raise SystemExit(1)
if __name__ == "__main__":
main()