File size: 3,205 Bytes
3e600b5
fed0900
 
3e600b5
f06c1f2
93150d4
 
 
fed0900
 
 
 
f06c1f2
3e600b5
f06c1f2
fed0900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e600b5
fed0900
 
 
 
 
 
 
 
 
 
 
 
3e600b5
 
f06c1f2
 
 
fed0900
f06c1f2
fed0900
 
 
f06c1f2
 
fed0900
 
 
 
 
 
 
 
 
 
 
f06c1f2
fed0900
f06c1f2
 
 
 
 
3e600b5
 
b5d27a5
f06c1f2
 
 
 
 
 
fed0900
f06c1f2
 
 
 
b5d27a5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
CLI client for the ragstudio backend. Talks HTTP to the long-running server so
model weights stay loaded in one process.

Usage:
    python query.py "your query" [top_k] [-m semantic-text|bm25|image|text|...]

`text` is a group alias expanding to `semantic-text` + `bm25`.

Set RAGSTUDIO_URL to point at a non-default backend (default
`http://127.0.0.1:8000`). If the backend is not reachable, the CLI exits with
a clear instruction to start it.
"""

import argparse
import json
import os
import sys
import urllib.error
import urllib.parse
import urllib.request

# Base URL of the ragstudio backend, overridable via RAGSTUDIO_URL. An empty
# value falls back to the default, and trailing slashes are dropped so that
# f"{DEFAULT_URL}/api/..." never produces a "//api/..." path.
DEFAULT_URL = (
    os.environ.get("RAGSTUDIO_URL") or "http://127.0.0.1:8000"
).rstrip("/")


def _http_json(url: str) -> dict:
    """GET *url* and decode the response body as a JSON object.

    HTTP error responses whose body is a JSON object are returned as-is, so
    callers can inspect fields such as ``error``. Every other failure ends
    the CLI with a readable message instead of a traceback.

    Raises:
        SystemExit: on an error response whose body is not a JSON object, a
            success response that is not a JSON object, an unreachable
            backend, or a socket-level failure (e.g. a timeout) while reading.
    """
    try:
        with urllib.request.urlopen(url, timeout=120) as r:
            raw = r.read()
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", "replace")
        try:
            data = json.loads(body)
        except ValueError:  # JSONDecodeError is a ValueError subclass
            data = None
        if not isinstance(data, dict):
            raise SystemExit(f"backend error {e.code}: {body}") from None
        return data
    except urllib.error.URLError as e:
        raise SystemExit(
            f"backend not reachable at {DEFAULT_URL} ({e.reason}).\n"
            f"Start it with: python backend/server.py"
        ) from None
    except OSError as e:
        # A timeout or reset while reading the body surfaces as a bare OSError
        # (e.g. TimeoutError), not wrapped in URLError.
        raise SystemExit(f"request to {url} failed: {e}") from None

    try:
        data = json.loads(raw.decode("utf-8"))
    except ValueError:  # covers both UnicodeDecodeError and JSONDecodeError
        data = None
    if not isinstance(data, dict):
        raise SystemExit(
            f"backend returned a non-JSON-object response from {url}"
        )
    return data


def _modalities() -> tuple[list[str], dict[str, list[str]]]:
    """Fetch the modality names and group aliases the backend serves.

    Returns:
        ``(modalities, groups)``, where ``groups`` maps an alias to its member
        modalities. It is an empty dict when the backend reports no groups.

    Raises:
        SystemExit: when the response carries no ``modalities`` list, e.g.
            because the backend answered with an error payload.
    """
    data = _http_json(f"{DEFAULT_URL}/api/modalities")
    mods = data.get("modalities") if isinstance(data, dict) else None
    if not isinstance(mods, list):
        # _http_json passes JSON error bodies through, so report them here
        # instead of failing with a KeyError traceback.
        detail = data.get("error", data) if isinstance(data, dict) else data
        raise SystemExit(f"backend did not report modalities: {detail}")
    # A JSON null for "groups" would otherwise break _expand's .get() calls.
    return mods, data.get("groups") or {}


def _expand(modalities, groups):
    """Resolve group aliases in *modalities* into concrete modality names.

    Entries that are keys of *groups* are replaced by that group's members,
    and all other entries are kept verbatim. Duplicates are dropped, keeping
    the first occurrence of each name in its original position.
    """
    # dict preserves insertion order, so its keys act as an ordered set.
    ordered: dict[str, None] = {}
    for requested in modalities:
        members = groups[requested] if requested in groups else (requested,)
        for member in members:
            ordered.setdefault(member, None)
    return list(ordered)


def search(
    query: str,
    top_k: int = 5,
    modalities=None,
) -> dict[str, list[tuple[float, str]]]:
    """Run *query* against each selected modality and print the hits.

    Args:
        query: Search text, sent as the ``q`` query parameter.
        top_k: Hits requested per modality, sent as the ``k`` query parameter.
        modalities: Modality names or group aliases to search. ``None``
            searches every modality the backend reports.

    Returns:
        Mapping of each searched modality to its ``(score, path)`` pairs.
        A modality whose response has an ``error`` and no hits maps to an
        empty list. Unknown names are reported and left out of the mapping.
    """
    known, group_map = _modalities()
    if modalities is None:
        wanted = list(known)
    else:
        wanted = _expand(modalities, group_map)

    collected: dict[str, list[tuple[float, str]]] = {}
    for mod in wanted:
        if mod not in known:
            print(f"  ! unknown modality: {mod} (known: {sorted(known)})")
            continue
        params = urllib.parse.urlencode({"q": query, "k": top_k})
        endpoint = f"{DEFAULT_URL}/api/search/{urllib.parse.quote(mod)}?{params}"
        payload = _http_json(endpoint)
        if "error" in payload and not payload.get("hits"):
            print(f"\n=== {mod}: error ===\n  {payload['error']}")
            collected[mod] = []
        else:
            ranked = [
                (float(hit["score"]), hit["path"])
                for hit in payload.get("hits", [])
            ]
            collected[mod] = ranked
            print(f"\n=== Top {len(ranked)} {mod} matches ===")
            for score, path in ranked:
                print(f"  {score:.3f}  {path}")
    return collected


def main() -> None:
    """Parse command-line arguments and run ``search`` with them.

    Exits with argparse's usage error (status 2) when ``top_k`` is below 1,
    or when ``-m`` is given but names no modality after trimming whitespace.
    """
    parser = argparse.ArgumentParser(
        # __doc__ is None under `python -OO`, which strips docstrings.
        description=(__doc__ or "").strip() or None,
        # Keep the docstring's line breaks (its Usage section) in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("query")
    parser.add_argument("top_k", nargs="?", type=int, default=5)
    parser.add_argument(
        "-m",
        "--modalities",
        help="Comma-separated modalities or group names (default: all).",
    )
    args = parser.parse_args()
    if args.top_k < 1:
        parser.error(f"top_k must be at least 1 (got {args.top_k})")

    mods = None
    if args.modalities:
        # Accept "text, image" and stray commas such as "text," gracefully.
        mods = [m.strip() for m in args.modalities.split(",") if m.strip()]
        if not mods:
            parser.error("-m/--modalities needs at least one modality name")
    search(args.query, args.top_k, mods)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()