Spaces:
Build error
Build error
Commit ·
46a4da8
1
Parent(s): f447477
Initial commit
Browse files- .gitignore +5 -0
- .python-version +1 -0
- README.md +11 -1
- build_json_rag.py +67 -0
- data/.keep +0 -0
- main.py +16 -0
- pyproject.toml +16 -0
- tools/rag_search.py +51 -0
- utils/tool_chat.py +76 -0
- uv.lock +0 -0
.gitignore
CHANGED
|
@@ -205,3 +205,8 @@ cython_debug/
|
|
| 205 |
marimo/_static/
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
marimo/_static/
|
| 206 |
marimo/_lsp/
|
| 207 |
__marimo__/
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
data/*.json
|
| 211 |
+
.vscode/
|
| 212 |
+
rag_index_json/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.14
|
README.md
CHANGED
|
@@ -1 +1,11 @@
|
|
| 1 |
-
# llm-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# llm-agents
|
| 2 |
+
|
| 3 |
+
Install Ollama and pull `nomic-embed-text` and `llama3.1` models
|
| 4 |
+
|
| 5 |
+
### Agent for NeurIPS 2025 papers
|
| 6 |
+
|
| 7 |
+
Download https://neurips.cc/static/virtual/data/neurips-2025-orals-posters.json and put the file in the `./data` folder.
|
| 8 |
+
|
| 9 |
+
Run `uv run build_json_rag.py` to build the vector database for RAG.
|
| 10 |
+
|
| 11 |
+
Then run using `uv run main.py`.
|
build_json_rag.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rag_build_json.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import time
|
| 5 |
+
from typing import List, Dict, Any, Tuple
|
| 6 |
+
import json, os, math
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import faiss
|
| 11 |
+
from litellm import embedding
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
API_BASE = "http://localhost:11434"
|
| 15 |
+
EMBED_MODEL = "ollama/nomic-embed-text" # pull with: ollama pull nomic-embed-text
|
| 16 |
+
INDEX_DIR = "rag_index_json"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _load_papers(json_path: str) -> List[Paper]:
|
| 20 |
+
data = json.loads(Path(json_path).read_text())
|
| 21 |
+
papers = data["results"]
|
| 22 |
+
return papers
|
| 23 |
+
|
| 24 |
+
def _embed(texts: List[str], batch: int = 1) -> np.ndarray:
    """Embed *texts* with the local Ollama model and L2-normalize each row.

    Texts are sent to the embedding endpoint in chunks of *batch*; each
    chunk is retried on failure (the local server occasionally errors on
    individual requests). Returns a float32 array of shape
    (len(texts), dim) whose rows have unit norm, so inner product with
    IndexFlatIP equals cosine similarity.

    Raises RuntimeError (a subclass of Exception, so existing callers'
    handlers still match) when a chunk keeps failing after all retries.
    """
    max_retries = 100
    retry_delay_s = 0.1

    vecs: List[List[float]] = []
    for i in tqdm(range(0, len(texts), batch)):
        chunk = texts[i:i + batch]
        for attempt in range(1, max_retries + 1):
            try:
                resp = embedding(model=EMBED_MODEL, input=chunk, api_base=API_BASE)
                break
            except Exception as e:  # transient server/network errors
                print(f"Error during embedding batch {i}-{i + batch}: {e}")
                if attempt == max_retries:
                    raise RuntimeError("Max retries reached for embedding.") from e
                print(f"Retrying {attempt} ...")
                time.sleep(retry_delay_s)

        vecs.extend(d["embedding"] for d in resp["data"])

    arr = np.array(vecs, dtype="float32")
    # Cosine similarity: normalize rows to unit length so IndexFlatIP
    # (inner product) returns cosine scores; epsilon avoids divide-by-zero.
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms
|
| 49 |
+
|
| 50 |
+
def build_index(json_path: str, out_dir: str = INDEX_DIR) -> None:
    """Embed every paper in *json_path* and persist a FAISS index + metadata.

    Writes `faiss.index` (inner-product index over unit-norm vectors) and
    `meta.json` (the raw paper records, aligned by row) into *out_dir*.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    papers = _load_papers(json_path)

    # Title + abstract is the retrieval text. Records missing either field
    # are embedded as "" so vector rows stay aligned with `papers` indices.
    texts: List[str] = []
    for record in papers:
        if "name" in record and "abstract" in record:
            texts.append(f"{record['name']}\n\n{record['abstract']}")
        else:
            texts.append("")

    X = _embed(texts)
    index = faiss.IndexFlatIP(X.shape[1])
    index.add(X)

    faiss.write_index(index, str(out_path / "faiss.index"))
    # Metadata lives alongside the embeddings so search can return it by row.
    (out_path / "meta.json").write_text(json.dumps(papers, ensure_ascii=False, indent=4, sort_keys=True))
    print(f"Indexed {len(papers)} papers → {out_path.resolve()}")
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
    # Small fixture for smoke-testing the pipeline:
    # build_index("data/papers_test.json")
    build_index("data/neurips-2025-orals-posters-pretty.json")
|
data/.keep
ADDED
|
File without changes
|
main.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI entry point: ask the RAG-backed research assistant one question.

The query is taken from the command-line arguments when given, otherwise
from an interactive prompt.
"""
from tools.rag_search import search_papers  # <- the tool we expose
from utils.tool_chat import ToolChat
import sys

if __name__ == "__main__":
    chat = ToolChat()
    system = {"role": "system", "content": "You are a research assistant. Use the search tool when helpful, but you may want to search for a larger number of relevant papers, then select from those that are most relevant to the user's query by further examining the titles and abstracts of the papers found. Cite titles, abstracts, and OpenReview URLs in your answers."}

    # The prompt always comes from argv or stdin; the previous hard-coded
    # default was dead code (both branches overwrote it unconditionally).
    if len(sys.argv) > 1:
        query = " ".join(sys.argv[1:])
    else:
        query = input("Enter user prompt: ")
    user = {"role": "user", "content": query}

    resp = chat.tool_loop([system, user], registry={"search_papers": search_papers})
    print(resp["choices"][0]["message"]["content"])
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "llm-agents"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "LLM agent with RAG search over the NeurIPS 2025 papers dataset"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.14"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"faiss-cpu>=1.12.0",
|
| 9 |
+
"httpx>=0.28.1",
|
| 10 |
+
"litellm>=1.79.0",
|
| 11 |
+
"pydantic>=2.12.3",
|
| 12 |
+
"pydantic-settings>=2.11.0",
|
| 13 |
+
"rich>=14.2.0",
|
| 14 |
+
"sentence-transformers>=5.1.2",
|
| 15 |
+
"typer>=0.20.0",
|
| 16 |
+
]
|
tools/rag_search.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
import faiss
|
| 7 |
+
from litellm import embedding
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
API_BASE = "http://localhost:11434"
|
| 11 |
+
EMBED_MODEL = "ollama/nomic-embed-text"
|
| 12 |
+
INDEX_DIR = "rag_index_json"
|
| 13 |
+
|
| 14 |
+
# Load once at import (fast)
|
| 15 |
+
_FAISS = faiss.read_index(str(Path(INDEX_DIR) / "faiss.index"))
|
| 16 |
+
_META = json.loads(Path(INDEX_DIR, "meta.json").read_text())
|
| 17 |
+
|
| 18 |
+
def _embed_query(q: str) -> np.ndarray:
    """Embed a single query string; return a unit-norm (1, dim) float32 array."""
    response = embedding(model=EMBED_MODEL, input=[q], api_base=API_BASE)
    vec = np.array(response["data"][0]["embedding"], dtype="float32")
    # Normalize so inner-product search over the index yields cosine scores.
    unit = vec / (np.linalg.norm(vec) + 1e-12)
    return unit.reshape(1, -1)
|
| 23 |
+
|
| 24 |
+
def _snippet(text: str, max_chars: int = 240) -> str:
|
| 25 |
+
text = " ".join(text.split())
|
| 26 |
+
return text if len(text) <= max_chars else text[:max_chars-1] + "…"
|
| 27 |
+
|
| 28 |
+
def search_papers(query: str, k: int = 10) -> Dict[str, Any]:
    """
    Search top-k relevant papers over title+abstract. You need to specify k as the total number of results that you would like to get.
    Returns JSON with 'query' and 'results' (the results include metadata of the paper).
    """
    # NOTE: the docstring above is sent to the model verbatim as the tool
    # description (see infer_tool), so it is left unchanged.
    # Tool arguments may arrive as strings from the LLM; coerce defensively.
    try:
        k = int(k)
    except (TypeError, ValueError):
        k = 5

    start_time = time.time()
    qv = _embed_query(query)
    scores, idxs = _FAISS.search(qv, k)
    scores, idxs = scores[0].tolist(), idxs[0].tolist()
    print(f"Search for '{query}' took {time.time() - start_time:.3f}s")

    # FAISS pads with -1 when fewer than k vectors exist; skip those slots.
    # (A dead per-hit `text = name + abstract` computation was removed.)
    results = [_META[i] for i in idxs if i != -1]
    return {"query": query, "results": results}
|
utils/tool_chat.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Any, Dict, List, Callable, get_args, get_origin, Literal, Annotated, Union
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
import json, inspect
|
| 5 |
+
from litellm import completion
|
| 6 |
+
|
| 7 |
+
from tools.rag_search import search_papers # <- the tool we expose
|
| 8 |
+
|
| 9 |
+
# ---- Minimal auto-schema from function signature ----
|
| 10 |
+
def _json_type(t: Any) -> Dict[str, Any]:
|
| 11 |
+
origin, args = get_origin(t), get_args(t)
|
| 12 |
+
if origin is Literal:
|
| 13 |
+
return {"type": "string", "enum": list(args)}
|
| 14 |
+
if origin in (list, List):
|
| 15 |
+
return {"type": "array", "items": {"type": "string"}}
|
| 16 |
+
if t in (str,): return {"type": "string"}
|
| 17 |
+
if t in (int,): return {"type": "integer"}
|
| 18 |
+
if t in (float,): return {"type": "number"}
|
| 19 |
+
if t in (bool,): return {"type": "boolean"}
|
| 20 |
+
return {"type": "string"}
|
| 21 |
+
|
| 22 |
+
def infer_tool(func: Callable[..., Any]) -> Dict[str, Any]:
    """Build an OpenAI-style tool spec from *func*'s signature and docstring."""
    parameters = inspect.signature(func).parameters
    annotations = getattr(func, "__annotations__", {})

    properties: Dict[str, Any] = {}
    required: List[str] = []
    for pname, param in parameters.items():
        if pname in ("self", "cls"):
            continue
        properties[pname] = _json_type(annotations.get(pname, str))
        # Parameters without defaults are mandatory for the model to supply.
        if param.default is inspect._empty:
            required.append(pname)

    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": inspect.getdoc(func) or f"Call {func.__name__}",
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }
|
| 39 |
+
|
| 40 |
+
# ---- LiteLLM chat wrapper with tool loop ----
|
| 41 |
+
@dataclass
class ToolChat:
    """Chat wrapper around litellm.completion with an automatic tool loop."""

    model: str = "ollama_chat/llama3.1"
    # model: str = "ollama_chat/ibm/granite4:350m"
    api_base: str = "http://localhost:11434"
    # Generation defaults merged into every completion call.
    default_params: Dict[str, Any] = field(default_factory=lambda: {"temperature": 0.2, "max_tokens": 2000})

    def tool_loop(self, messages: List[Dict[str, Any]], registry: Dict[str, Callable[..., Any]], max_rounds: int = 3) -> Dict[str, Any]:
        """Run a conversation, executing requested tools until the model answers.

        messages:   initial conversation; not mutated (a copy is extended).
        registry:   tool name -> callable; results must be JSON-serializable.
        max_rounds: cap on model/tool round-trips; afterwards a final
                    completion without tools produces the answer.
        Returns the raw litellm response of the final completion.
        """
        tools = [infer_tool(fn) for fn in registry.values()]
        msgs = list(messages)
        for _ in range(max_rounds):
            resp = completion(model=self.model, messages=msgs, tools=tools, tool_choice="auto", api_base=self.api_base, **self.default_params)

            msg = resp["choices"][0].get("message", {})
            calls = msg.get("tool_calls") or []
            print("tool_calls:", calls)

            if not calls:
                return resp

            # Fix: the assistant message that requested the tools must be
            # appended before the role:"tool" results, otherwise the next
            # completion sees tool outputs with no matching tool_calls.
            msgs.append(msg)

            # execute tools and append results
            for call in calls:
                name = call["function"]["name"]
                args = call["function"].get("arguments", "{}")
                try:
                    parsed = json.loads(args) if isinstance(args, str) else (args or {})
                except json.JSONDecodeError:
                    parsed = {}
                fn = registry.get(name)
                # Guard against hallucinated tool names instead of raising KeyError.
                out = fn(**parsed) if fn is not None else {"error": f"unknown tool: {name}"}
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call.get("id"),
                    "name": name,
                    "content": json.dumps(out, ensure_ascii=False),
                })
        # Tool budget exhausted: ask for a final answer without tools.
        return completion(model=self.model, messages=msgs, api_base=self.api_base, **self.default_params)
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|