"""
CLI client for the ragstudio backend. Talks HTTP to the long-running server so
model weights stay loaded in one process.
Usage:
python query.py "your query" [top_k] [-m semantic-text|bm25|image|text|...]
`text` is a group alias expanding to `semantic-text` + `bm25`.
Set RAGSTUDIO_URL to point at a non-default backend (default
`http://127.0.0.1:8000`). If the backend is not reachable, the CLI exits with
a clear instruction to start it.
"""
import argparse
import json
import os
import sys
import urllib.error
import urllib.parse
import urllib.request
# Base URL of the ragstudio backend. Read once at import time from the
# RAGSTUDIO_URL environment variable, falling back to the local default.
DEFAULT_URL = os.environ.get("RAGSTUDIO_URL", "http://127.0.0.1:8000")
def _http_json(url: str) -> dict:
try:
with urllib.request.urlopen(url, timeout=120) as r:
return json.loads(r.read().decode("utf-8"))
except urllib.error.HTTPError as e:
body = e.read().decode("utf-8", "replace")
try:
return json.loads(body)
except Exception:
raise SystemExit(f"backend error {e.code}: {body}") from None
except urllib.error.URLError as e:
raise SystemExit(
f"backend not reachable at {DEFAULT_URL} ({e.reason}).\n"
f"Start it with: python backend/server.py"
) from None
def _modalities() -> tuple[list[str], dict[str, list[str]]]:
    """Fetch the backend's modality listing.

    Returns the ``"modalities"`` entry of the ``/api/modalities`` response
    together with its ``"groups"`` entry (an empty dict when absent).
    """
    payload = _http_json(DEFAULT_URL + "/api/modalities")
    names = payload["modalities"]
    group_map = payload.get("groups", {})
    return names, group_map
def _expand(modalities, groups):
out: list[str] = []
for m in modalities:
for name in groups.get(m, (m,)):
if name not in out:
out.append(name)
return out
def search(
    query: str,
    top_k: int = 5,
    modalities=None,
) -> dict[str, list[tuple[float, str]]]:
    """Query the backend once per selected modality and print the hits.

    When *modalities* is None every modality reported by the backend is
    searched; otherwise the given names are expanded through the backend's
    group aliases first. Names the backend does not list are reported and
    skipped (they get no entry in the result). A response that carries an
    ``"error"`` and no hits is reported and recorded as an empty list.

    Returns a mapping from modality name to a list of ``(score, path)``
    tuples in the order the backend returned them.
    """
    known, group_map = _modalities()
    if modalities is None:
        wanted = list(known)
    else:
        wanted = _expand(modalities, group_map)

    found: dict[str, list[tuple[float, str]]] = {}
    for mod in wanted:
        if mod not in known:
            print(f" ! unknown modality: {mod} (known: {sorted(known)})")
            continue

        endpoint = f"{DEFAULT_URL}/api/search/{urllib.parse.quote(mod)}?"
        endpoint += urllib.parse.urlencode({"q": query, "k": top_k})
        payload = _http_json(endpoint)

        if "error" in payload and not payload.get("hits"):
            print(f"\n=== {mod}: error ===\n {payload['error']}")
            found[mod] = []
            continue

        ranked = []
        for hit in payload.get("hits", []):
            ranked.append((float(hit["score"]), hit["path"]))
        found[mod] = ranked

        print(f"\n=== Top {len(ranked)} {mod} matches ===")
        for score, item in ranked:
            print(f" {score:.3f} {item}")
    return found
def main() -> None:
    """Parse the command line and run a search against the backend.

    Positional arguments are the query and an optional integer ``top_k``
    (default 5). ``-m/--modalities`` takes a comma-separated list; when it
    is omitted or empty, all modalities are searched.
    """
    cli = argparse.ArgumentParser(description=__doc__.strip())
    cli.add_argument("query")
    cli.add_argument("top_k", nargs="?", type=int, default=5)
    cli.add_argument(
        "-m",
        "--modalities",
        help="Comma-separated modalities or group names (default: all).",
    )
    opts = cli.parse_args()

    chosen = None
    if opts.modalities:
        chosen = opts.modalities.split(",")
    search(opts.query, opts.top_k, chosen)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|