tarekziade's picture
tarekziade HF Staff
highlight
e330b99
"""
ragstudio backend.
Single long-running HTTP server that owns the searchers (and therefore the
loaded model weights). Both the UI and the CLI talk to it over HTTP.
Routes:
GET / -> ui/static/index.html
GET /static/<file> -> ui static asset
GET /api/modalities -> {"modalities": [...], "groups": {...}}
GET /api/search/<modality>?q=&k= -> {modality, kind, query, hits:[{score,path}]}
GET /api/file?path= -> serve a file referenced in any index's
*_meta.json (with stale-path healing)
Usage:
python backend/server.py [--host 127.0.0.1] [--port 8000]
"""
import argparse
import json
import mimetypes
import os
import re
import sys
import threading
import time
import urllib.parse
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
# Seconds the background sync thread sleeps between pulls of the index bucket.
SYNC_INTERVAL_S = 30
# Video index entries look like "<path> @ <t>s". The allow-list must hold the
# bare file path, so the trailing timestamp marker is removed.
_VIDEO_FRAME_SUFFIX = re.compile(r" @ \d+(?:\.\d+)?s$")


def _strip_frame_suffix(entry: str) -> str:
    """Return *entry* without a trailing ``" @ <seconds>s"`` frame marker.

    Entries that do not end with such a marker are returned unchanged.
    """
    marker = _VIDEO_FRAME_SUFFIX.search(entry)
    if marker is None:
        return entry
    # Keep whatever follows the match (``$`` may match just before a final
    # newline), exactly as a regex substitution would.
    return entry[: marker.start()] + entry[marker.end() :]
# Repository root: the grandparent directory of this file.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Directory holding the indexing code (searchers, sync) and the index data.
INDEXING_DIR = REPO_ROOT / "indexing"
# Where the *_meta.json files and the _source.txt manifest are read from.
INDEX_DATA_DIR = INDEXING_DIR / "index_data"
# Static UI assets served under "/" and "/static/".
UI_STATIC_DIR = REPO_ROOT / "ui" / "static"
def _source_root() -> Path | None:
    """Return the folder recorded in ``index_data/_source.txt``, if usable.

    The manifest is expected to hold the folder originally passed to
    build_index() (recorded by index.py).

    Returns:
        The recorded directory, or ``None`` when the manifest is missing,
        unreadable, empty, or points at something that is not a directory.
    """
    manifest = INDEX_DATA_DIR / "_source.txt"
    try:
        # Read directly instead of exists()+read: the sync thread may replace
        # the manifest between the two calls.
        recorded = manifest.read_text().strip()
    except (OSError, UnicodeDecodeError):
        return None
    if not recorded:
        # Path("") is ".", which would silently turn the current working
        # directory into the source root.
        return None
    root = Path(recorded)
    return root if root.is_dir() else None
# Searchers use relative `index_data/`; cd into indexing/ so paths resolve.
os.chdir(INDEXING_DIR)
# Make modules living in indexing/ (`searchers`, and later `sync`) importable
# as top-level modules.
sys.path.insert(0, str(INDEXING_DIR))
# Must come after the chdir/sys.path setup above, hence the E402 suppression.
from searchers import GROUPS, SEARCHERS # noqa: E402
def _modality_kind(name: str) -> str:
    """Map a modality name to the result kind reported to clients.

    ``"image"`` and ``"video"`` map to themselves; every other modality is
    reported as ``"text"``.
    """
    return name if name in ("image", "video") else "text"
def _load_allowed_paths() -> set[str]:
    """Build the file allow-list from every ``*_meta.json`` in the index dir.

    A meta file contributes only when its top-level JSON value is a list.
    Each list entry may be a path string or a dict with a string ``"path"``
    value; video frame suffixes are stripped so the set holds file paths.

    Returns:
        The union of all referenced paths. Meta files that cannot be read or
        parsed, and entries of any other shape, are skipped.
    """
    allowed: set[str] = set()
    for meta in INDEX_DATA_DIR.glob("*_meta.json"):
        try:
            data = json.loads(meta.read_text())
        except (OSError, ValueError):
            # Best effort: a half-written or corrupt meta file (the sync
            # thread may be updating it) must not break the allow-list.
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            continue
        if not isinstance(data, list):
            continue
        for entry in data:
            raw = entry.get("path") if isinstance(entry, dict) else entry
            # Only strings can be paths; a non-string "path" value would
            # otherwise blow up inside the regex substitution.
            if isinstance(raw, str):
                allowed.add(_strip_frame_suffix(raw))
    return allowed
# Allow-list snapshot taken at import time; the /api/file route reloads it
# when a requested path is not in the current set.
ALLOWED_PATHS = _load_allowed_paths()
# basename -> actual filesystem path, for healing stale absolute paths when
# indexed files have since been moved. Built lazily on first miss.
_BASENAME_INDEX: dict[str, str] | None = None
def _build_basename_index() -> dict[str, str]:
    """Map basenames of allow-listed paths to a file found under the repo.

    INDEXING_DIR is walked first and REPO_ROOT second, so a file under
    indexing/ wins when the same basename exists in both. For each basename
    the first matching file encountered is kept.

    Returns:
        ``{basename: absolute path string}`` for every allow-listed basename
        that was found on disk.
    """
    needed = {Path(p).name for p in ALLOWED_PATHS}
    found: dict[str, str] = {}
    if not needed:
        # Nothing to look for: skip the recursive walk entirely.
        return found
    visited: set[Path] = set()
    for root in (INDEXING_DIR, REPO_ROOT):
        try:
            root = root.resolve()
        except (OSError, RuntimeError):
            # Unresolvable root (e.g. symlink loop): skip it.
            continue
        if root in visited:
            continue
        visited.add(root)
        for path in root.rglob("*"):
            name = path.name
            # Cheap set lookups first, so only plausible candidates pay for
            # the is_file() stat call.
            if name not in needed or name in found:
                continue
            if not path.is_file():
                continue
            found[name] = str(path)
            if len(found) == len(needed):
                # Every basename is resolved; later entries could never
                # replace an earlier hit, so stop walking.
                return found
    return found
# Number of candidate units included on each side of the matched unit.
_EXCERPT_CONTEXT_LINES = 4
# Maximum characters kept per excerpt unit before it is cut and marked "…".
_EXCERPT_LINE_CHARS = 240
# Word tokenizer used when counting query/line token overlap.
_EXCERPT_TOKEN_RE = re.compile(r"\w+")
# Sentence boundary: whitespace that follows ".", "!" or "?" and precedes an
# ASCII letter or digit. Only used when the text has too few line breaks to
# serve as candidate units on their own.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+(?=[A-Za-z0-9])")


def _split_candidates(text: str) -> list[str]:
    """Break *text* into the units an excerpt can be centered on.

    Blank lines are dropped and every unit is whitespace-stripped. When the
    text has at least four non-blank lines those lines are the units;
    otherwise each line is treated as a paragraph and split into sentences.
    """
    stripped = (raw.strip() for raw in text.splitlines())
    paragraphs = [line for line in stripped if line]
    if len(paragraphs) >= 4:
        # The newlines already chop the text into enough pieces; trust them.
        return paragraphs

    sentences: list[str] = []
    for paragraph in paragraphs:
        for piece in _SENTENCE_SPLIT_RE.split(paragraph):
            piece = piece.strip()
            if piece:
                sentences.append(piece)
    return sentences
def _cap_line(s: str) -> str:
    """Return *s*, cut to ``_EXCERPT_LINE_CHARS`` characters plus "…" if longer."""
    if len(s) > _EXCERPT_LINE_CHARS:
        return s[:_EXCERPT_LINE_CHARS] + "…"
    return s
def _bm25_match_idx(query: str, lines: list[str]) -> int:
    """Return the index of the line sharing the most tokens with *query*.

    Despite the name, this is a plain count of line tokens that also occur in
    the query (case-insensitive), not a BM25 score. Ties go to the earliest
    line. Returns 0 when the query has no tokens or *lines* is empty.
    """
    wanted = set(_EXCERPT_TOKEN_RE.findall(query.lower()))
    if not wanted:
        return 0
    overlaps = [
        sum(token in wanted for token in _EXCERPT_TOKEN_RE.findall(text.lower()))
        for text in lines
    ]
    if not overlaps:
        return 0
    # list.index returns the first position of the maximum, which keeps the
    # earliest-line-wins tie-break.
    return overlaps.index(max(overlaps))
def _semantic_match_idx(query: str, lines: list[str]) -> int:
    """Return the index of the line whose embedding scores highest against *query*.

    The score is the dot product between each line embedding and the query
    embedding. Embeddings are requested normalized, so this presumably acts
    as cosine similarity (confirm against the model's encode() contract).
    """
    # Reuse the model already loaded by the semantic-text searcher so we don't
    # double the VRAM/load time. This is a private import on purpose.
    from searchers.semantic_text import _get_model
    import numpy as np

    encoder = _get_model()
    shared_opts = {"convert_to_numpy": True, "normalize_embeddings": True}
    query_vec = encoder.encode([query], **shared_opts)
    line_vecs = encoder.encode(
        lines,
        batch_size=64,
        show_progress_bar=False,
        **shared_opts,
    )
    scores = (line_vecs @ query_vec.T).flatten()
    return int(np.argmax(scores))
def _text_excerpt(modality: str, query: str, requested: str) -> dict | None:
    """Build a ``{before, match, after}`` snippet around the best-matching unit.

    The file behind *requested* is split into candidate units. The unit most
    relevant to *query* is chosen with the embedding matcher for the
    ``"semantic-text"`` modality and the token-overlap matcher otherwise.
    Up to ``_EXCERPT_CONTEXT_LINES`` units on each side form the context.

    Returns:
        The snippet dict, or ``None`` when the path cannot be resolved, the
        file cannot be read, or it contains no candidate units.
    """
    source = _resolve_path(requested)
    if source is None:
        return None
    try:
        text = source.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return None

    units = _split_candidates(text)
    if not units:
        return None

    pick = _semantic_match_idx if modality == "semantic-text" else _bm25_match_idx
    hit = pick(query, units)

    lo = max(0, hit - _EXCERPT_CONTEXT_LINES)
    hi = min(len(units), hit + _EXCERPT_CONTEXT_LINES + 1)
    leading = [_cap_line(unit) for unit in units[lo:hit]]
    trailing = [_cap_line(unit) for unit in units[hit + 1 : hi]]
    return {
        "before": "\n".join(leading),
        "match": _cap_line(units[hit]),
        "after": "\n".join(trailing),
    }
def _resolve_path(requested: str) -> Path | None:
p = Path(requested)
if p.is_absolute() and p.is_file():
return p
# Indexer stores paths relative to the original source folder. Try:
# 1. <INDEX_DATA_DIR>/source/<rel> -- copy pulled from the bucket
# 2. <_source.txt>/<rel> -- original folder on this machine
if not p.is_absolute():
bucket_copy = INDEX_DATA_DIR / "source" / p
if bucket_copy.is_file():
return bucket_copy
src = _source_root()
if src is not None:
local_copy = src / p
if local_copy.is_file():
return local_copy
# Last-resort: basename match anywhere under the repo, for indexes that
# still hold stale absolute paths or whose source folder has moved.
global _BASENAME_INDEX
if _BASENAME_INDEX is None:
_BASENAME_INDEX = _build_basename_index()
hit = _BASENAME_INDEX.get(p.name)
return Path(hit) if hit else None
class Handler(BaseHTTPRequestHandler):
    """Request handler for the UI, the search API and indexed-file serving.

    Routes (GET only):
      /, /index.html          -> ui/static/index.html
      /static/<file>          -> UI static asset (confined to UI_STATIC_DIR)
      /api/modalities         -> available searchers and their groups
      /api/search/<modality>  -> search hits for ?q= (and optional ?k=)
      /api/file?path=         -> a file referenced by an index meta file
    """

    # Bytes copied per read when streaming a file body to the client.
    _CHUNK_SIZE = 64 * 1024
    # Single-range header "bytes=<first>-<last>"; either bound may be empty.
    _RANGE_RE = re.compile(r"bytes=(\d*)-(\d*)", re.ASCII)

    def handle(self):
        """Serve the connection, ignoring clients that disconnect mid-response."""
        try:
            super().handle()
        except (BrokenPipeError, ConnectionResetError):
            # Browsers routinely abort media requests; this is not an error.
            pass

    def log_message(self, fmt, *args):
        """Write access/error log lines to stderr with a ``[backend]`` prefix."""
        sys.stderr.write(f"[backend] {self.address_string()} - {fmt % args}\n")

    def _send_json(self, status: int, payload):
        """Send *payload* as a non-cacheable UTF-8 JSON response."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    @classmethod
    def _parse_range(cls, header, size: int):
        """Interpret a ``Range`` header for a file of *size* bytes.

        Returns:
            ``None`` when the header is absent, malformed, lists several
            ranges, or has ``last < first`` -- the caller then serves the
            whole file. Otherwise an inclusive ``(start, end)`` pair with
            ``end`` clamped to the file. ``start >= size`` signals a range
            that cannot be satisfied.
        """
        if not header:
            return None
        match = cls._RANGE_RE.fullmatch(header.strip())
        if match is None:
            return None
        first, last = match.groups()
        if not first:
            if not last:
                return None  # "bytes=-" carries no information
            # Suffix form "bytes=-N": the final N bytes of the file.
            suffix = int(last)
            if suffix == 0:
                return size, size - 1
            return max(0, size - suffix), size - 1
        start = int(first)
        if not last:
            return start, size - 1
        end = int(last)
        if end < start:
            return None
        return start, min(end, size - 1)

    def _send_file(self, path: Path, status: int = 200):
        """Stream *path* to the client, honouring a single-range ``Range`` header.

        Args:
            path: File to serve; a JSON 404 is sent when it is not a
                readable regular file.
            status: Status code used for full (non-range) responses.

        A usable range yields 206 with ``Content-Range``; a range starting at
        or beyond the end of the file yields 416. Unusable headers are
        ignored and the full file is sent. The body is copied in chunks so
        large media files are never held in memory at once.
        """
        if not path.is_file():
            self._send_json(404, {"error": f"not found: {path}"})
            return
        try:
            fh = open(path, "rb")
        except OSError:
            # File vanished or became unreadable after the is_file() check.
            self._send_json(404, {"error": f"not found: {path}"})
            return
        with fh:
            # Size of the file actually opened, not of whatever the path
            # points at by now.
            size = os.fstat(fh.fileno()).st_size
            ctype, _ = mimetypes.guess_type(str(path))
            ctype = ctype or "application/octet-stream"
            span = self._parse_range(self.headers.get("Range"), size)

            if span is None:
                start, length = 0, size
                self.send_response(status)
                self.send_header("Content-Type", ctype)
            else:
                start, end = span
                if start >= size:
                    self.send_response(416)
                    self.send_header("Content-Range", f"bytes */{size}")
                    self.send_header("Content-Length", "0")
                    self.send_header("Accept-Ranges", "bytes")
                    self.end_headers()
                    return
                length = end - start + 1
                self.send_response(206)
                self.send_header("Content-Type", ctype)
                self.send_header("Content-Range", f"bytes {start}-{end}/{size}")
            self.send_header("Content-Length", str(length))
            self.send_header("Accept-Ranges", "bytes")
            self.end_headers()

            fh.seek(start)
            remaining = length
            while remaining > 0:
                chunk = fh.read(min(self._CHUNK_SIZE, remaining))
                if not chunk:
                    break  # file shrank underneath us; nothing more to send
                self.wfile.write(chunk)
                remaining -= len(chunk)

    def _run_search(self, modality: str, query: str, top_k: int):
        """Run the searcher for *modality* and send its hits as JSON.

        Responds 404 for an unknown modality and 500 (with an empty hit list)
        when the searcher raises. For text modalities each hit gets an
        ``excerpt`` when one can be built; an excerpt failure is logged and
        that hit is returned without one.
        """
        fn = SEARCHERS.get(modality)
        if fn is None:
            self._send_json(404, {"error": f"unknown modality: {modality}"})
            return
        kind = _modality_kind(modality)
        try:
            hits = fn(query, top_k)
        except Exception as exc:
            self._send_json(
                500,
                {
                    "modality": modality,
                    "kind": kind,
                    "query": query,
                    "error": str(exc),
                    "hits": [],
                },
            )
            return
        out = []
        for s, p in hits:
            entry = {"score": float(s), "path": p}
            if kind == "text":
                try:
                    excerpt = _text_excerpt(modality, query, p)
                except Exception as exc:
                    # Excerpts are decoration: never lose the hit list (or
                    # drop the connection) because one could not be built.
                    self.log_error("excerpt failed for %s: %s", p, exc)
                    excerpt = None
                if excerpt:
                    entry["excerpt"] = excerpt
            out.append(entry)
        self._send_json(
            200,
            {
                "modality": modality,
                "kind": kind,
                "query": query,
                "hits": out,
            },
        )

    def do_GET(self):
        """Dispatch a GET request to the static, API or file route."""
        global ALLOWED_PATHS

        parsed = urllib.parse.urlparse(self.path)
        route = parsed.path
        qs = urllib.parse.parse_qs(parsed.query)

        # static UI
        if route in ("/", "/index.html"):
            self._send_file(UI_STATIC_DIR / "index.html")
            return
        if route.startswith("/static/"):
            target = (UI_STATIC_DIR / route[len("/static/") :]).resolve()
            root = UI_STATIC_DIR.resolve()
            # Reject anything that resolves outside the static directory.
            if root != target and root not in target.parents:
                self._send_json(403, {"error": "forbidden"})
                return
            self._send_file(target)
            return

        # API
        if route == "/api/modalities":
            self._send_json(
                200,
                {
                    "modalities": sorted(SEARCHERS.keys()),
                    "groups": {k: list(v) for k, v in GROUPS.items()},
                },
            )
            return
        if route.startswith("/api/search/"):
            # The URL path is not percent-decoded by urlparse.
            modality = urllib.parse.unquote(route[len("/api/search/") :])
            query = (qs.get("q", [""])[0] or "").strip()
            if not query:
                self._send_json(400, {"error": "missing q"})
                return
            try:
                top_k = int(qs.get("k", ["8"])[0])
            except ValueError:
                top_k = 0
            if top_k < 1:
                self._send_json(400, {"error": "k must be a positive integer"})
                return
            self._run_search(modality, query, top_k)
            return
        if route == "/api/file":
            path = qs.get("path", [""])[0]
            if not path:
                self._send_json(400, {"error": "missing path"})
                return
            if path not in ALLOWED_PATHS:
                # The index may have been updated since the last load;
                # refresh once before refusing.
                ALLOWED_PATHS = _load_allowed_paths()
            if path not in ALLOWED_PATHS:
                self._send_json(403, {"error": "path not in any index"})
                return
            resolved = _resolve_path(path)
            if resolved is None:
                self._send_json(404, {"error": f"file missing on disk: {path}"})
                return
            self._send_file(resolved)
            return
        self._send_json(404, {"error": f"no route for {route}"})
def _sync_loop() -> None:
    """Pull the shared index bucket forever, pausing SYNC_INTERVAL_S between pulls.

    A failed pull is reported on stderr and does not stop the loop. The
    function never returns.
    """
    # `sync` lives in INDEXING_DIR, which is already on sys.path.
    import sync as bucket_sync

    while True:
        try:
            bucket_sync.sync(pull=True)
        except Exception as err:
            sys.stderr.write(f"[sync] pull failed: {err}\n")
        time.sleep(SYNC_INTERVAL_S)
def main() -> None:
    """Parse CLI options, start the background sync thread and serve HTTP.

    Serves until interrupted with Ctrl-C, then closes the listening socket.
    """
    cli = argparse.ArgumentParser(description=__doc__.strip())
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    # Daemon thread: it must not keep the process alive after shutdown.
    puller = threading.Thread(target=_sync_loop, daemon=True)
    puller.start()

    server = ThreadingHTTPServer((opts.host, opts.port), Handler)
    banner = [
        f"ragstudio backend on http://{opts.host}:{opts.port}",
        f" cwd: {os.getcwd()}",
        f" searchers: {sorted(SEARCHERS)}",
        f" static UI: {UI_STATIC_DIR}",
        f" index files: {len(ALLOWED_PATHS)}",
        f" sync: pull every {SYNC_INTERVAL_S}s",
    ]
    print("\n".join(banner))

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nshutting down")
    server.server_close()
# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    main()