#!/usr/bin/env python3
"""Minimal RAG-style retrieval + simple faithfulness check (Horizon 1 short-term C).
Chunks a FAQ markdown corpus by `##` sections, embeds with TinyModelRuntime, retrieves top
matches for a query, and reports **keyword overlap** in the top hit as a cheap faithfulness
proxy (not neural entailment). Optional **--show-train-routing** prints Phase 2 **`routing`**
notes from the checkpoint's **eval_report.json** (same helper as **embeddings_smoke_test** /
**horizon1_route_then_retrieve**)."""
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
from typing import Any
# Put this file's own directory on sys.path before importing eval_report_routing,
# so the import below also resolves when this file is loaded from another directory.
_scripts = Path(__file__).resolve().parent
if str(_scripts) not in sys.path:
    sys.path.insert(0, str(_scripts))
from eval_report_routing import maybe_print_routing_section
# Program name: used as argparse ``prog`` and passed to the routing helper.
_PROG = "rag_faq_smoke"
# Stop words removed by tokenize() before the keyword-overlap computation.
_STOP = frozenset(
    "a an the to of and or for in on at is are was be as it with from by not"
    .split()
)
# Two levels above this file; treated as the repo root for relative paths.
_REPO = Path(__file__).resolve().parent.parent
# When --model is omitted, first existing dir wins; else published Hub weights.
_DEFAULT_MODEL_DIRS = (
    "artifacts/horizon1/three-tasks/ag_news",
    "artifacts/phase1/runs/smoke/ag_news/scratch",
    ".tmp/TinyModel-local",
    ".tmp/horizon1-verify-a",
)
# Hugging Face model id used when no default directory contains config.json.
_DEFAULT_HUB = "HyperlinksSpace/TinyModel1"
def _pick_model(explicit: str | None) -> str:
    """Resolve a local checkpoint dir, or a Hugging Face model id (namespace/name).

    Args:
        explicit: The ``--model`` value, or ``None`` when the flag was omitted.

    Returns:
        The resolved path of a directory containing ``config.json``; otherwise
        ``_DEFAULT_HUB`` (no explicit value, no default dir found) or
        ``explicit`` unchanged (no such local path, so it is passed through
        as a Hub id).

    Raises:
        SystemExit: With status 1 when ``explicit`` names an existing path
            (as given, or relative to ``_REPO``) that is not a directory
            containing ``config.json``.
    """
    if explicit is None:
        # First default directory that has a config.json wins.
        for rel in _DEFAULT_MODEL_DIRS:
            candidate = _REPO / rel
            if (candidate / "config.json").is_file():
                return str(candidate.resolve())
        return _DEFAULT_HUB

    as_given = Path(explicit)
    repo_relative = _REPO / explicit
    # Try the path as given (CWD-relative or absolute) first, then under _REPO.
    for candidate in (as_given.resolve(), repo_relative.resolve()):
        if candidate.is_dir() and (candidate / "config.json").is_file():
            return str(candidate)

    if as_given.exists() or repo_relative.exists():
        # The path exists locally but is not a checkpoint: report it rather
        # than passing a filesystem path on as if it were a Hub id.
        print(
            f"Not a model directory (expected config.json): {explicit!r}\n"
            "Train first, e.g.:\n"
            " python scripts/train_tinymodel1_agnews.py --output-dir .tmp/rag-encoder "
            "--max-train-samples 200 --max-eval-samples 100 --epochs 1 --batch-size 8 --seed 42",
            file=sys.stderr,
        )
        raise SystemExit(1)
    return explicit  # Hub id, e.g. HyperlinksSpace/TinyModel1
def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for this script.

    The module docstring is used as the description and a short list of
    example invocations as the epilog; both are shown verbatim because the
    parser uses ``RawDescriptionHelpFormatter``.

    Returns:
        A parser exposing ``--model``, ``--corpus``, ``--top-k``,
        ``--semantic-only``, ``--query`` and ``--show-train-routing``.
    """
    epilog = (
        "Examples:\n"
        " python scripts/rag_faq_smoke.py\n"
        " python scripts/rag_faq_smoke.py --query \"How do I get a refund?\" --top-k 3\n"
        " python scripts/rag_faq_smoke.py --model artifacts/phase1/runs/smoke/ag_news/scratch "
        "--show-train-routing\n"
        "If --model is omitted, the first default checkpoint dir with config.json is used, "
        f"else {_DEFAULT_HUB!r} (see --model above)."
    )
    p = argparse.ArgumentParser(
        prog=_PROG,
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    p.add_argument(
        "--model",
        type=str,
        default=None,
        help=(
            "Trained checkpoint directory or Hugging Face model id. "
            f"If omitted, uses the first of {_DEFAULT_MODEL_DIRS} that contains config.json, "
            f"else {_DEFAULT_HUB!r}."
        ),
    )
    p.add_argument(
        "--corpus",
        type=str,
        default="texts/rag_faq_corpus.md",
        help="Markdown file with ##-delimited chunks.",
    )
    p.add_argument(
        "--top-k",
        type=int,
        default=2,
        help="Number of top-ranked chunks to retrieve per query (default: 2).",
    )
    p.add_argument(
        "--semantic-only",
        action="store_true",
        help="Use only TinyModelRuntime.retrieve (stricter; tiny encoders may fail on short FAQ chunks).",
    )
    p.add_argument(
        "--query",
        type=str,
        default=None,
        help=(
            "If set, run a single retrieval for this query and print top-k chunks with scores "
            "(citation-style index into the chunk list). Skips the built-in smoke assertions."
        ),
    )
    p.add_argument(
        "--show-train-routing",
        action="store_true",
        help="Print eval_report.json top-level routing (Phase 2 notes) before retrieval output.",
    )
    return p
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments with the parser from ``build_parser()``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        The populated argument namespace.
    """
    return build_parser().parse_args(argv)
def load_chunks(corpus: Path) -> list[str]:
    """Split a markdown corpus into one chunk per ``##`` section.

    Each chunk is ``"<title>\\n<body>"`` with both parts stripped. Sections
    whose body is empty are skipped, and any text before the first heading
    is not included in a chunk.

    Args:
        corpus: Path of a UTF-8 markdown file.

    Returns:
        The section chunks in file order. When no section yields a chunk, a
        single-element list holding the whole stripped file text (which is
        ``[""]`` for an empty file).
    """
    text = corpus.read_text(encoding="utf-8")
    # `re.split` with one capture group returns
    # [preamble, title1, body1, title2, body2, ...], so titles sit at the odd
    # indices and each is always followed by its body.
    # The heading separator is `[^\S\n]+` (whitespace except newline) and the
    # title must start with a non-space character, so a bare `##` line cannot
    # swallow the following line as its title.
    parts = re.split(r"(?m)^##[^\S\n]+(\S.*)$", text)
    chunks: list[str] = []
    for raw_title, raw_body in zip(parts[1::2], parts[2::2]):
        body = raw_body.strip()
        if body:
            chunks.append(f"{raw_title.strip()}\n{body}")
    return chunks if chunks else [text.strip()]
def tokenize(s: str) -> set[str]:
    """Return the distinct lowercased tokens of *s* that are not in ``_STOP``.

    A token is a run of ASCII letters, digits or apostrophes.
    """
    # Lowercase each word once, then filter against the stop-word set.
    words = (w.lower() for w in re.findall(r"[A-Za-z0-9']+", s))
    return {w for w in words if w not in _STOP}
def overlap_faithfulness(query: str, chunk: str) -> float:
    """Return the fraction of the query's tokens that also occur in the chunk.

    Both texts are tokenized with ``tokenize()``, so stop words are ignored
    and matching is case-insensitive.

    Args:
        query: The user query.
        chunk: The retrieved chunk text.

    Returns:
        A value in ``[0.0, 1.0]``; ``0.0`` when the query has no tokens left
        after tokenization.
    """
    query_tokens = tokenize(query)
    if not query_tokens:
        return 0.0
    # The guard above guarantees a non-zero denominator.
    return len(query_tokens & tokenize(chunk)) / len(query_tokens)
def lex_substring_score(query: str, chunk: str) -> float:
    """Cheap overlap: fraction of 3+ char alphanumeric query tokens that appear as substrings.

    Matching is case-insensitive and by substring, so ``"track"`` matches a
    chunk containing ``"tracking"``. Repeated query tokens are counted each
    time they occur.

    Args:
        query: The user query.
        chunk: The candidate chunk text.

    Returns:
        Hits divided by the number of counted tokens; ``0.0`` when the query
        has no token of three or more characters.
    """
    haystack = chunk.lower()
    words = [w for w in re.findall(r"[a-z0-9]+", query.lower()) if len(w) >= 3]
    if not words:
        return 0.0
    return sum(w in haystack for w in words) / len(words)
def hybrid_retrieve(
    rt: TinyModelRuntime,
    query: str,
    chunks: list[str],
    *,
    top_k: int,
    embed_weight: float = 0.45,
) -> list[tuple[float, int, str]]:
    """Combine cosine (encoder) + lexical overlap so tiny scratch encoders still rank sensible FAQ chunks.

    The score of each chunk is
    ``embed_weight * cos + (1 - embed_weight) * lex_substring_score(query, chunk)``.

    Args:
        rt: Runtime whose ``embed(texts, normalize=True)`` is called once
            with the query followed by all chunks.
        query: The user query.
        chunks: Candidate chunk texts.
        top_k: Maximum number of results to return.
        embed_weight: Weight of the embedding similarity term; the lexical
            term gets the remainder.

    Returns:
        Up to ``top_k`` ``(score, chunk_index, chunk_text)`` tuples, highest
        score first. Empty when ``chunks`` is empty or ``top_k`` is not
        positive.
    """
    # A non-positive top_k asks for nothing; without this guard a negative
    # value would slice from the end and return "all but the last |k|" chunks.
    if not chunks or top_k <= 0:
        return []

    # NOTE(review): assumes rt.embed returns a 2-D array/tensor with one row
    # per input text, in input order — confirm against TinyModelRuntime.
    embs = rt.embed([query, *chunks], normalize=True)
    query_emb = embs[0:1]
    chunk_embs = embs[1:]
    # One similarity per chunk; this is cosine if the rows are unit-normalized.
    cos = (query_emb @ chunk_embs.T).squeeze(0)

    lex_weight = 1.0 - embed_weight
    scored: list[tuple[float, int]] = [
        (embed_weight * float(cos[i]) + lex_weight * lex_substring_score(query, ch), i)
        for i, ch in enumerate(chunks)
    ]
    # The sort is stable, so chunks with equal scores stay in corpus order.
    scored.sort(key=lambda pair: -pair[0])
    return [(score, i, chunks[i]) for score, i in scored[:top_k]]
def main() -> None:
    """Run the RAG FAQ smoke check, or a single retrieval when ``--query`` is set.

    Steps: resolve the model, load the corpus chunks, optionally print the
    routing section, build a CPU ``TinyModelRuntime``, then either

    * print the top-k chunks for ``--query`` and return, or
    * run the built-in sample queries and check the top-1 chunk of each. A
      sample passes when the top chunk contains the expected substring or
      its keyword overlap with the query is at least 0.12.

    Raises:
        SystemExit: With status 1 when the corpus file is missing or any
            smoke sample fails (``_pick_model`` may also exit with 1).
    """
    args = parse_args()
    model_id = _pick_model(args.model)
    if args.model is None:
        print(f"rag_faq_smoke: using --model {model_id!r} (set explicitly to override).", file=sys.stderr)

    corpus = Path(args.corpus)
    if not corpus.is_file():
        # Same lookup order as --model: the path as given first, then relative
        # to the repo root, so the default corpus is found from any CWD.
        repo_corpus = _REPO / args.corpus
        if repo_corpus.is_file():
            corpus = repo_corpus
    if not corpus.is_file():
        print(f"Corpus not found: {corpus}", file=sys.stderr)
        raise SystemExit(1)
    chunks = load_chunks(corpus)

    maybe_print_routing_section(
        model_id, enabled=args.show_train_routing, prog=_PROG,
    )

    # Imported only after the argument and corpus checks above have passed.
    from tinymodel_runtime import TinyModelRuntime

    rt = TinyModelRuntime(model_id, device="cpu", max_length=128)

    if args.query:
        q = args.query.strip()
        print("=== RAG FAQ (single query) ===\n")
        print(f"model={model_id!r}\ncorpus={corpus}\nquery={q!r}\n")
        if args.semantic_only:
            hits = rt.retrieve(q, chunks, top_k=args.top_k)
            for rank, h in enumerate(hits, 1):
                prev = h.text[:240].replace("\n", " ")
                print(f" #{rank} idx={h.index} score={h.score:.4f} {prev!r}...")
        else:
            hr = hybrid_retrieve(rt, q, chunks, top_k=args.top_k)
            for rank, (score, idx, text) in enumerate(hr, 1):
                prev = text[:240].replace("\n", " ")
                print(f" #{rank} idx={idx} hybrid_score={score:.4f} {prev!r}...")
        return

    print("=== RAG FAQ smoke (retrieval) ===\n")
    # (query, substring that must appear in top-1 chunk for a pass — citation-style check)
    samples: list[tuple[str, str]] = [
        ("How do I get a refund for my order?", "refund"),
        ("I see an unauthorized login on my account", "password"),
        ('My package tracking says exception, what do I do?', "exception"),
    ]
    # The smoke check always inspects the top-1 hit, so request at least one
    # result even if --top-k was given as 0 or a negative number.
    smoke_top_k = max(args.top_k, 1)
    all_ok = True
    for q, must in samples:
        if args.semantic_only:
            hits = rt.retrieve(q, chunks, top_k=smoke_top_k)
            top_text = hits[0].text
            top_score = hits[0].score
        else:
            hr = hybrid_retrieve(rt, q, chunks, top_k=smoke_top_k)
            top_score, _idx, top_text = hr[0]
        overlap = overlap_faithfulness(q, top_text)
        cited = must.lower() in top_text.lower()
        ok = cited or overlap >= 0.12
        if not ok:
            all_ok = False
        status = "ok" if ok else "fail"
        print(f"Q: {q}")
        print(
            f" top hybrid/semantic score={top_score:.4f} keyword_overlap={overlap:.2f} "
            f"contains({must!r})={cited} [{status}]"
        )
        # Force the preview to ASCII; non-ASCII characters are printed as '?'.
        safe = top_text[:200].replace(chr(10), " ").encode("ascii", "replace").decode("ascii")
        print(f" chunk preview: {safe}...")
        print()

    if all_ok:
        print(
            "RAG FAQ smoke: passed (default: hybrid lexical + encoder; use --semantic-only to stress pure embedding retrieval).",
        )
    else:
        print(
            "RAG smoke failed: re-train the encoder, use a larger/HF model, or add training pairs.",
            file=sys.stderr,
        )
        raise SystemExit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()