shiqiangwang commited on
Commit
46a4da8
·
1 Parent(s): f447477

Initial commit

Browse files
Files changed (10) hide show
  1. .gitignore +5 -0
  2. .python-version +1 -0
  3. README.md +11 -1
  4. build_json_rag.py +67 -0
  5. data/.keep +0 -0
  6. main.py +16 -0
  7. pyproject.toml +16 -0
  8. tools/rag_search.py +51 -0
  9. utils/tool_chat.py +76 -0
  10. uv.lock +0 -0
.gitignore CHANGED
@@ -205,3 +205,8 @@ cython_debug/
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
 
 
 
 
 
 
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
+
209
+
210
+ data/*.json
211
+ .vscode/
212
+ rag_index_json/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.14
README.md CHANGED
@@ -1 +1,11 @@
1
- # llm-agent
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm-agents
2
+
3
+ Install Ollama and pull the `nomic-embed-text` and `llama3.1` models.
4
+
5
+ ### Agent for NeurIPS 2025 papers
6
+
7
+ Download https://neurips.cc/static/virtual/data/neurips-2025-orals-posters.json and put the file in the `./data` folder.
8
+
9
+ Run `uv run build_json_rag.py` to build the vector database for RAG.
10
+
11
+ Then run using `uv run main.py`.
build_json_rag.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # build_json_rag.py
2
+ from __future__ import annotations
3
+ from dataclasses import dataclass
4
+ import time
5
+ from typing import List, Dict, Any, Tuple
6
+ import json, os, math
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import faiss
11
+ from litellm import embedding
12
+ from tqdm import tqdm
13
+
14
+ API_BASE = "http://localhost:11434"
15
+ EMBED_MODEL = "ollama/nomic-embed-text" # pull with: ollama pull nomic-embed-text
16
+ INDEX_DIR = "rag_index_json"
17
+
18
+
19
+ def _load_papers(json_path: str) -> List[Paper]:
20
+ data = json.loads(Path(json_path).read_text())
21
+ papers = data["results"]
22
+ return papers
23
+
24
def _embed(texts: List[str], batch: int = 1) -> np.ndarray:
    """Embed `texts` with the Ollama embedding model, returning unit-norm rows.

    Args:
        texts: strings to embed, one vector per string.
        batch: number of texts sent per embedding request.

    Returns:
        A (len(texts), dim) float32 array, L2-normalized so that inner
        product equals cosine similarity (pairs with faiss.IndexFlatIP).

    Raises:
        RuntimeError: if a batch still fails after `max_retries` attempts.
    """
    max_retries = 100
    vecs: List[List[float]] = []
    for start in tqdm(range(0, len(texts), batch)):
        chunk = texts[start:start + batch]
        # Bounded retry loop (replaces the original while/counter bookkeeping).
        for attempt in range(1, max_retries + 1):
            try:
                resp = embedding(model=EMBED_MODEL, input=chunk, api_base=API_BASE)
                break
            except Exception as e:  # local server fails transiently; retry
                print(f"Error during embedding batch {start}-{start + batch}: {e}")
                if attempt >= max_retries:
                    raise RuntimeError("Max retries reached for embedding.") from e
                print(f"Retrying {attempt} ...")
                time.sleep(0.1)
        vecs.extend(d["embedding"] for d in resp["data"])
    # NOTE(review): assumes texts is non-empty; an empty input yields a 1-D
    # array and the axis=1 norm below would raise — confirm against callers.
    arr = np.array(vecs, dtype="float32")
    # cosine similarity: normalize to unit length and use IndexFlatIP
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms
49
+
50
def build_index(json_path: str, out_dir: str = INDEX_DIR) -> None:
    """Embed every paper (title + abstract) and persist a FAISS index plus metadata.

    Args:
        json_path: path to the NeurIPS JSON dump.
        out_dir: directory receiving faiss.index and meta.json.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    papers = _load_papers(json_path)
    # Papers missing either field embed the empty string so that row i of the
    # index always lines up with papers[i] in meta.json.
    corpus = [
        f"{p['name']}\n\n{p['abstract']}" if "name" in p and "abstract" in p else ""
        for p in papers
    ]

    vectors = _embed(corpus)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)

    faiss.write_index(index, str(out_path / "faiss.index"))
    # metadata alongside embeddings
    (out_path / "meta.json").write_text(
        json.dumps(papers, ensure_ascii=False, indent=4, sort_keys=True)
    )
    print(f"Indexed {len(papers)} papers → {out_path.resolve()}")
64
+
65
if __name__ == "__main__":
    # build_index("data/papers_test.json")
    # File name matches what the README says to download into ./data
    # (the previous "-pretty" suffix pointed at a file no step creates).
    build_index("data/neurips-2025-orals-posters.json")
data/.keep ADDED
File without changes
main.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from tools.rag_search import search_papers  # <- the tool we expose
from utils.tool_chat import ToolChat
import sys

if __name__ == "__main__":
    chat = ToolChat()
    system = {"role": "system", "content": "You are a research assistant. Use the search tool when helpful, but you may want to search for a larger number of relevant papers, then select from those that are most relevant to the user's query by further examining the titles and abstracts of the papers found. Cite titles, abstracts, and OpenReview URLs in your answers."}

    # Prompt comes from argv when given, otherwise from stdin. (The original
    # also set a hard-coded default that was overwritten on both paths.)
    prompt = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else input("Enter user prompt: ")
    user = {"role": "user", "content": prompt}

    resp = chat.tool_loop([system, user], registry={"search_papers": search_papers})
    print(resp["choices"][0]["message"]["content"])
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "llm-agents"
3
+ version = "0.1.0"
4
+ description = "RAG agent over NeurIPS 2025 papers using Ollama, FAISS, and LiteLLM"
5
+ readme = "README.md"
6
+ requires-python = ">=3.14"
7
+ dependencies = [
8
+ "faiss-cpu>=1.12.0",
9
+ "httpx>=0.28.1",
10
+ "litellm>=1.79.0",
11
+ "pydantic>=2.12.3",
12
+ "pydantic-settings>=2.11.0",
13
+ "rich>=14.2.0",
14
+ "sentence-transformers>=5.1.2",
15
+ "typer>=0.20.0",
16
+ ]
tools/rag_search.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import List, Dict, Any
3
+ import json
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import faiss
7
+ from litellm import embedding
8
+ import time
9
+
10
+ API_BASE = "http://localhost:11434"
11
+ EMBED_MODEL = "ollama/nomic-embed-text"
12
+ INDEX_DIR = "rag_index_json"
13
+
14
# Load the index and metadata once at import time.
# NOTE(review): this raises at import if build_json_rag.py has not been run yet.
_INDEX_PATH = Path(INDEX_DIR)
_FAISS = faiss.read_index(str(_INDEX_PATH / "faiss.index"))
_META = json.loads((_INDEX_PATH / "meta.json").read_text())
17
+
18
def _embed_query(q: str) -> np.ndarray:
    """Embed a single query string as a unit-norm (1, dim) float32 array."""
    resp = embedding(model=EMBED_MODEL, input=[q], api_base=API_BASE)
    vec = np.asarray(resp["data"][0]["embedding"], dtype="float32")
    # Normalize so inner-product search against the index is cosine similarity.
    vec = vec / (np.linalg.norm(vec) + 1e-12)
    return vec.reshape(1, -1)
23
+
24
+ def _snippet(text: str, max_chars: int = 240) -> str:
25
+ text = " ".join(text.split())
26
+ return text if len(text) <= max_chars else text[:max_chars-1] + "…"
27
+
28
def search_papers(query: str, k: int = 10) -> Dict[str, Any]:
    """
    Search top-k relevant papers over title+abstract. You need to specify k as the total number of results that you would like to get.
    Returns JSON with 'query' and 'results' (the results include metadata of the paper).
    """
    # k may arrive as a string (or junk) from a model tool call; coerce and
    # fall back to 5 on anything unusable.  Narrowed from `except Exception`.
    try:
        k = int(k)
    except (TypeError, ValueError):
        k = 5

    start_time = time.time()
    qv = _embed_query(query)
    scores, idxs = _FAISS.search(qv, k)
    scores, idxs = scores[0].tolist(), idxs[0].tolist()
    print(f"Search for '{query}' took {time.time() - start_time:.3f}s")

    results = []
    for _score, i in zip(scores, idxs):
        if i == -1:  # FAISS pads with -1 when fewer than k hits exist
            continue
        # (Removed a dead local that rebuilt "name\n\nabstract" per hit and
        # discarded it — the full metadata record is what gets returned.)
        results.append(_META[i])
    return {"query": query, "results": results}
utils/tool_chat.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any, Dict, List, Callable, get_args, get_origin, Literal, Annotated, Union
3
+ from dataclasses import dataclass, field
4
+ import json, inspect
5
+ from litellm import completion
6
+
7
+ from tools.rag_search import search_papers # <- the tool we expose
8
+
9
+ # ---- Minimal auto-schema from function signature ----
10
+ def _json_type(t: Any) -> Dict[str, Any]:
11
+ origin, args = get_origin(t), get_args(t)
12
+ if origin is Literal:
13
+ return {"type": "string", "enum": list(args)}
14
+ if origin in (list, List):
15
+ return {"type": "array", "items": {"type": "string"}}
16
+ if t in (str,): return {"type": "string"}
17
+ if t in (int,): return {"type": "integer"}
18
+ if t in (float,): return {"type": "number"}
19
+ if t in (bool,): return {"type": "boolean"}
20
+ return {"type": "string"}
21
+
22
def infer_tool(func: Callable[..., Any]) -> Dict[str, Any]:
    """Build an OpenAI-style tool schema from a function's signature.

    Parameter types come from annotations via _json_type; parameters without
    a default are marked required. The docstring becomes the description.
    """
    sig = inspect.signature(func)
    hints = getattr(func, "__annotations__", {})
    props: Dict[str, Any] = {}
    required: List[str] = []
    for name, p in sig.parameters.items():
        if name in ("self", "cls"):
            continue
        # *args / **kwargs have no sensible JSON-schema entry; skip them
        # (the original emitted bogus string properties for them).
        if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        # Public sentinel; the original used the private inspect._empty.
        if p.default is inspect.Parameter.empty:
            required.append(name)
        props[name] = _json_type(hints.get(name, str))
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": (inspect.getdoc(func) or f"Call {func.__name__}"),
            "parameters": {"type": "object", "properties": props, "required": required},
        },
    }
39
+
40
# ---- LiteLLM chat wrapper with tool loop ----
@dataclass
class ToolChat:
    """Thin wrapper around litellm.completion with a bounded tool-calling loop."""
    model: str = "ollama_chat/llama3.1"
    # model: str = "ollama_chat/ibm/granite4:350m"
    api_base: str = "http://localhost:11434"
    default_params: Dict[str, Any] = field(default_factory=lambda: {"temperature": 0.2, "max_tokens": 2000})

    def tool_loop(self, messages: List[Dict[str, Any]], registry: Dict[str, Callable[..., Any]], max_rounds: int = 3) -> Dict[str, Any]:
        """Run up to `max_rounds` completions, executing requested tools between rounds.

        Args:
            messages: initial chat transcript (not mutated; a copy is extended).
            registry: tool name -> callable; schemas are inferred via infer_tool.
            max_rounds: maximum tool-calling completions before a final answer.

        Returns:
            The litellm completion response containing the final assistant message.
        """
        tools = [infer_tool(fn) for fn in registry.values()]
        msgs = list(messages)
        for _ in range(max_rounds):
            resp = completion(model=self.model, messages=msgs, tools=tools, tool_choice="auto", api_base=self.api_base, **self.default_params)

            msg = resp["choices"][0].get("message", {})
            calls = msg.get("tool_calls") or []
            print("tool_calls:", calls)

            if not calls:
                return resp

            # FIX: the assistant message carrying tool_calls must precede the
            # "tool" result messages in the transcript; without it the next
            # completion sees orphan tool messages, which OpenAI-compatible
            # endpoints reject.
            msgs.append(msg)

            # execute tools and append results
            for call in calls:
                name = call["function"]["name"]
                args = call["function"].get("arguments", "{}")
                try:
                    parsed = json.loads(args) if isinstance(args, str) else (args or {})
                except json.JSONDecodeError:
                    parsed = {}
                fn = registry.get(name)
                # Models can hallucinate tool names; report instead of KeyError.
                out = fn(**parsed) if fn is not None else {"error": f"unknown tool: {name}"}
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call.get("id"),
                    "name": name,
                    "content": json.dumps(out, ensure_ascii=False),
                })
        # Rounds exhausted: one last completion without tools to force an answer.
        return completion(model=self.model, messages=msgs, api_base=self.api_base, **self.default_params)
uv.lock ADDED
The diff for this file is too large to render. See raw diff