VectorDB
Browse files- .gitignore +2 -0
- Makefile +14 -1
- app/core/rag/build.py +300 -0
- app/main.py +49 -25
- app/middleware.py +51 -8
- app/routers/chat.py +30 -19
- app/services/chat_service.py +236 -34
- configs/rag_sources.yaml +41 -0
- data/kb.jsonl +0 -0
- requirements.txt +5 -0
- scripts/build_kb.py +54 -0
.gitignore
CHANGED
|
@@ -32,3 +32,5 @@ Thumbs.db
|
|
| 32 |
# RAG index files
|
| 33 |
.faiss/
|
| 34 |
/backup
|
|
|
|
|
|
|
|
|
| 32 |
# RAG index files
|
| 33 |
.faiss/
|
| 34 |
/backup
|
| 35 |
+
copy *.*
|
| 36 |
+
* copy.*
|
Makefile
CHANGED
|
@@ -55,6 +55,10 @@ help:
|
|
| 55 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run" "Run uvicorn (PORT=$(PORT))"
|
| 56 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run-hot" "Run with --reload"
|
| 57 |
@echo
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
@echo "$(BRIGHT_GREEN)Docker$(RESET)"
|
| 59 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-build" "Build local image ($(IMG_NAME))"
|
| 60 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-run" "Run local container (maps $(PORT))"
|
|
@@ -100,6 +104,15 @@ run: install
|
|
| 100 |
run-hot: install
|
| 101 |
@PORT=$(PORT) $(VENV_DIR)/bin/uvicorn $(APP_MODULE) --host 0.0.0.0 --port $(PORT) --reload
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# ---------------------------------------------------------------------------
|
| 104 |
# Docker
|
| 105 |
# ---------------------------------------------------------------------------
|
|
@@ -121,4 +134,4 @@ space-url:
|
|
| 121 |
clean:
|
| 122 |
@rm -rf .venv __pycache__ .pytest_cache .ruff_cache .mypy_cache dist build *.egg-info
|
| 123 |
|
| 124 |
-
.PHONY: help venv install lint fmt test run run-hot docker-build docker-run space-url clean
|
|
|
|
| 55 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run" "Run uvicorn (PORT=$(PORT))"
|
| 56 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run-hot" "Run with --reload"
|
| 57 |
@echo
|
| 58 |
+
@echo "$(BRIGHT_GREEN)RAG / Knowledge Base$(RESET)"
|
| 59 |
+
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "kb" "Build/refresh KB from GitHub + local docs (writes data/kb.jsonl)"
|
| 60 |
+
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "kb-force" "Force rebuild KB (deletes existing data/kb.jsonl)"
|
| 61 |
+
@echo
|
| 62 |
@echo "$(BRIGHT_GREEN)Docker$(RESET)"
|
| 63 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-build" "Build local image ($(IMG_NAME))"
|
| 64 |
@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-run" "Run local container (maps $(PORT))"
|
|
|
|
| 104 |
run-hot: install
|
| 105 |
@PORT=$(PORT) $(VENV_DIR)/bin/uvicorn $(APP_MODULE) --host 0.0.0.0 --port $(PORT) --reload
|
| 106 |
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
# RAG / Knowledge Base
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
kb: install
|
| 111 |
+
@PYTHONPATH=. $(PYTHON) scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
|
| 112 |
+
|
| 113 |
+
kb-force: install
|
| 114 |
+
@rm -f data/kb.jsonl && PYTHONPATH=. $(PYTHON) scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
|
| 115 |
+
|
| 116 |
# ---------------------------------------------------------------------------
|
| 117 |
# Docker
|
| 118 |
# ---------------------------------------------------------------------------
|
|
|
|
| 134 |
clean:
|
| 135 |
@rm -rf .venv __pycache__ .pytest_cache .ruff_cache .mypy_cache dist build *.egg-info
|
| 136 |
|
| 137 |
+
.PHONY: help venv install lint fmt test run run-hot kb kb-force docker-build docker-run space-url clean
|
app/core/rag/build.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import json, os, re, time, math, logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Dict, List, Iterable, Tuple, Optional
|
| 5 |
+
|
| 6 |
+
import yaml
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
log = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# -------------------------
|
| 12 |
+
# Text cleaning & chunking
|
| 13 |
+
# -------------------------
|
| 14 |
+
|
| 15 |
+
# Matches a YAML front-matter header ("--- ... ---") at the very start of a Markdown file.
_MD_FRONTMATTER = re.compile(r"^---\s*\n.*?\n---\s*\n", re.DOTALL)
|
| 16 |
+
|
| 17 |
+
def normalize_text(text: str) -> str:
    """Strip noise from raw text.

    Trims each line, drops lines carrying fewer than 3 alphanumeric
    characters (which also removes blank lines), and collapses runs of
    3+ newlines down to a single blank line.
    """
    def _keep(line: str) -> bool:
        # A line is worth keeping only if it has some real content.
        return sum(1 for ch in line if ch.isalnum()) >= 3

    kept = [stripped for raw in text.splitlines() if _keep(stripped := raw.strip())]
    collapsed = re.sub(r"\n{3,}", "\n\n", "\n".join(kept))
    return collapsed.strip()
|
| 29 |
+
|
| 30 |
+
def md_to_text(md: str) -> str:
    """Convert Markdown to plain prose.

    Removes YAML front-matter, fenced code blocks and images, unwraps
    links to their label text, strips heading markers, inline backticks
    and blockquote prefixes, renders list bullets as "• ", then runs the
    result through normalize_text().
    """
    stripped = _MD_FRONTMATTER.sub("", md)
    # Order matters: drop whole constructs first, then unwrap inline markup.
    transformations = (
        (r"```.*?```", "", re.DOTALL),             # fenced code blocks
        (r"!\[[^\]]*\]\([^)]+\)", "", 0),          # images
        (r"\[([^\]]+)\]\([^)]+\)", r"\1", 0),      # links -> label text
        (r"^\s{0,3}#{1,6}\s*", "", re.MULTILINE),  # ATX headings
    )
    for pattern, repl, flags in transformations:
        stripped = re.sub(pattern, repl, stripped, flags=flags)
    stripped = stripped.replace("`", "")
    stripped = re.sub(r"^\s*[-*+]\s+", "• ", stripped, flags=re.MULTILINE)
    stripped = re.sub(r"^\s*>\s?", "", stripped, flags=re.MULTILINE)
    return normalize_text(stripped)
|
| 40 |
+
|
| 41 |
+
def chunk_text(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
    """Split *text* into chunks of at most *max_chars* characters.

    Paragraphs (blank-line separated) are greedily packed together; a
    paragraph longer than *max_chars* is sliced into windows overlapping
    by *overlap* characters so content at boundaries is not lost.

    Fixes vs. the naive version:
      * the pending buffer is flushed *before* an oversized paragraph is
        emitted, so chunks keep document order;
      * slicing stops as soon as a window reaches the paragraph's end,
        avoiding a redundant tail chunk that is a strict suffix of the
        previous one.
    """
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    out: List[str] = []
    buf = ""
    for p in paras:
        if len(p) > max_chars:
            # Flush pending buffer first so output stays in document order.
            if buf:
                out.append(buf)
                buf = ""
            i = 0
            while i < len(p):
                j = min(i + max_chars, len(p))
                out.append(p[i:j])
                if j == len(p):
                    break  # reached the end; no redundant overlap-only tail
                i = j - overlap if j - overlap > i else j
            continue
        if len(buf) + 2 + len(p) <= max_chars:
            buf = (buf + "\n\n" + p) if buf else p
        else:
            if buf:
                out.append(buf)
            buf = p
    if buf:
        out.append(buf)
    return out
|
| 62 |
+
|
| 63 |
+
def write_jsonl(records: Iterable[Dict], out_path: Path) -> None:
    """Serialize *records* to *out_path* as JSON Lines (one object per
    line), creating parent directories as needed. Overwrites any
    existing file."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    lines = (json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    with out_path.open("w", encoding="utf-8") as fh:
        fh.writelines(lines)
|
| 68 |
+
|
| 69 |
+
# -------------------------
|
| 70 |
+
# GitHub API helpers
|
| 71 |
+
# -------------------------
|
| 72 |
+
|
| 73 |
+
def gh_session() -> requests.Session:
    """Build a requests.Session preconfigured for the GitHub REST API.

    Sends the recommended Accept header and a custom User-Agent; if the
    GITHUB_TOKEN environment variable is set, authenticates with it
    (which raises the API rate limit).
    """
    session = requests.Session()
    headers = {
        "Accept": "application/vnd.github+json",
        "User-Agent": "matrix-ai-rag-builder/1.0",
    }
    token = os.getenv("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    session.headers.update(headers)
    return session
|
| 83 |
+
|
| 84 |
+
def gh_get_json(url: str, sess: requests.Session, max_retries: int = 3) -> Dict | List:
    """GET *url* and return the decoded JSON payload.

    Retries up to *max_retries* times with exponential backoff (capped at
    30s) when GitHub signals rate limiting — GitHub uses both 403 (with a
    "rate limit" message) and 429 for primary/secondary limits. A
    Retry-After header, when present, overrides the computed backoff.
    Raises requests.HTTPError on any other non-2xx response, or on the
    final rate-limited attempt.
    """
    backoff = 1.0
    r = None
    for _ in range(max_retries):
        r = sess.get(url, timeout=25)
        rate_limited = r.status_code == 429 or (
            r.status_code == 403 and "rate limit" in r.text.lower()
        )
        if rate_limited:
            # Prefer the server-provided delay when available.
            retry_after = r.headers.get("Retry-After")
            try:
                delay = min(float(retry_after), 30.0) if retry_after else backoff
            except ValueError:
                delay = backoff
            log.warning("GitHub rate-limited; sleeping %.1fs", delay)
            time.sleep(delay)
            backoff = min(backoff * 2, 30)
            continue
        r.raise_for_status()
        return r.json()
    if r is not None:
        r.raise_for_status()  # surface the last rate-limited response
    return {}
|
| 97 |
+
|
| 98 |
+
def gh_list_org_repos(org: str, sess: requests.Session) -> List[Dict]:
    """Return every repository of GitHub organization *org*, walking the
    paginated org-repos endpoint 100 entries at a time."""
    per_page = 100
    collected: List[Dict] = []
    page = 1
    while True:
        batch = gh_get_json(
            f"https://api.github.com/orgs/{org}/repos?per_page={per_page}&page={page}",
            sess,
        )
        if not batch:
            break
        collected.extend(batch)
        if len(batch) < per_page:
            break  # short page => last page
        page += 1
    return collected
|
| 111 |
+
|
| 112 |
+
def gh_list_tree(owner: str, repo: str, branch: str, sess: requests.Session) -> List[Dict]:
    """Fetch the full recursive git tree for *owner*/*repo* at *branch*.

    Returns the list of tree entries, or [] when the response is not the
    expected object (e.g. a missing branch).
    """
    api = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
    payload = gh_get_json(api, sess)
    if not isinstance(payload, dict):
        return []
    return payload.get("tree", [])
|
| 116 |
+
|
| 117 |
+
def gh_fetch_raw(owner: str, repo: str, branch: str, path: str, sess: requests.Session) -> Optional[str]:
    """Download a file's raw text from raw.githubusercontent.com.

    When *branch* is "main" and the file is missing, retries once against
    "master" (the older default-branch name). Returns None when the file
    cannot be fetched.
    """
    resp = sess.get(
        f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}", timeout=25
    )
    if resp.status_code == 404 and branch == "main":
        # Older repos may still use "master" as their default branch.
        resp = sess.get(
            f"https://raw.githubusercontent.com/{owner}/{repo}/master/{path}", timeout=25
        )
    return resp.text if resp.status_code == 200 else None
|
| 126 |
+
|
| 127 |
+
# -------------------------
|
| 128 |
+
# Builders
|
| 129 |
+
# -------------------------
|
| 130 |
+
|
| 131 |
+
def ingest_github_repo(owner: str, name: str, branch: str, docs_paths: List[str],
                       include_readme: bool, exts: Tuple[str, ...] = (".md", ".mdx", ".txt"),
                       sess: Optional[requests.Session] = None) -> List[Tuple[str, str]]:
    """Collect (source-id, plain-text) docs from one GitHub repository.

    Fetches the README (first existing candidate filename) when
    *include_readme* is set, then every blob under the *docs_paths*
    directories whose extension is in *exts*. Markdown files are
    converted to plain text; other files are only whitespace-normalized.

    *sess* lets callers reuse one (possibly authenticated) Session across
    many repos instead of creating a new one per call; a fresh session is
    built when omitted, so the original call signature keeps working.
    """
    sess = sess or gh_session()
    out: List[Tuple[str, str]] = []

    # README (first matching filename wins)
    if include_readme:
        for candidate in ("README.md", "readme.md", "README.MD"):
            t = gh_fetch_raw(owner, name, branch, candidate, sess)
            if t:
                out.append((f"github:{owner}/{name}/{candidate}", md_to_text(t)))
                break

    # Recursive tree -> files under the requested docs directories
    tree = gh_list_tree(owner, name, branch, sess)
    if not tree:
        return out

    wanted_dirs = [p.strip("/").lower() for p in docs_paths]
    for entry in tree:
        if entry.get("type") != "blob":
            continue
        path = entry.get("path", "")
        lower = path.lower()
        if not lower.endswith(exts):
            continue
        if any(lower.startswith(d + "/") for d in wanted_dirs):
            t = gh_fetch_raw(owner, name, branch, path, sess)
            if not t:
                continue
            txt = md_to_text(t) if lower.endswith((".md", ".mdx")) else normalize_text(t)
            if txt:
                out.append((f"github:{owner}/{name}/{path}", txt))
    return out
|
| 165 |
+
|
| 166 |
+
def ingest_github_sources(cfg: Dict) -> List[Tuple[str, str]]:
    """Gather (source-id, text) documents for every GitHub source in *cfg*.

    Supports two config shapes under cfg["github"]:
      * "repos": explicit entries with owner/name/branch/docs_paths;
      * "orgs": organization names whose repositories are each scanned
        for a README plus a top-level docs/ directory.
    A failure to list an org is logged and that org is skipped.
    """
    gh_cfg = cfg.get("github") or {}
    sess = gh_session()
    docs: List[Tuple[str, str]] = []

    # Explicitly configured repositories.
    for spec in (gh_cfg.get("repos") or []):
        docs.extend(
            ingest_github_repo(
                spec["owner"],
                spec["name"],
                spec.get("branch", "main"),
                spec.get("docs_paths", ["docs"]),
                bool(spec.get("include_readme", True)),
            )
        )

    # Whole-organization scans (README + docs/ for each repo).
    for org in (gh_cfg.get("orgs") or []):
        try:
            org_repos = gh_list_org_repos(org, sess)
        except Exception as exc:
            log.warning("Failed to list org %s: %s", org, exc)
            continue
        for repo in org_repos:
            docs.extend(
                ingest_github_repo(
                    repo["owner"]["login"],
                    repo["name"],
                    repo.get("default_branch", "main"),
                    ["docs"],
                    include_readme=True,
                )
            )
    return docs
|
| 194 |
+
|
| 195 |
+
def ingest_local_sources(cfg: Dict) -> List[Tuple[str, str]]:
    """Collect (path, plain-text) documents from local files/directories.

    cfg["local"]["paths"] may mix files and directories; directories are
    searched recursively with cfg["local"]["glob"] (default "**/*.md").
    Unreadable files are logged and skipped rather than aborting the run.

    The file-reading logic previously duplicated across the file and
    directory branches is factored into one shared helper.
    """
    out: List[Tuple[str, str]] = []
    local = cfg.get("local") or {}
    paths = local.get("paths") or []
    glob_pat = local.get("glob", "**/*.md")

    def _ingest_file(f: Path) -> None:
        # Shared reader for both the single-file and directory branches.
        try:
            raw = f.read_text(encoding="utf-8", errors="ignore")
            txt = md_to_text(raw) if f.suffix.lower() in {".md", ".mdx"} else normalize_text(raw)
            if txt:
                out.append((str(f), txt))
        except Exception as e:
            log.warning("Failed reading %s: %s", f, e)

    for p in paths:
        fp = Path(p)
        if fp.is_file():
            _ingest_file(fp)
        elif fp.is_dir():
            for f in fp.rglob(glob_pat):
                _ingest_file(f)
    return out
|
| 220 |
+
|
| 221 |
+
def build_kb_from_config(config_path: str = "configs/rag_sources.yaml",
                         out_jsonl: str = "data/kb.jsonl",
                         max_chars: int = 800,
                         overlap: int = 120,
                         minlen: int = 200,
                         dedupe: bool = True) -> int:
    """Build the JSONL knowledge base described by *config_path*.

    Ingests GitHub repos/orgs, local files, and plain URLs. Each stage is
    best-effort: a failing stage is logged and skipped. Text is chunked
    (dropping chunks shorter than *minlen*), optionally de-duplicated,
    and written as {"text", "source"} records to *out_jsonl*.

    Returns the number of records written; 0 when nothing was produced
    (in that case no file is written).
    """
    cfg: Dict = {}
    p = Path(config_path)
    if p.exists():
        cfg = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    else:
        log.warning("rag_sources.yaml not found at %s (using defaults)", p)

    records: List[Dict] = []

    def _add_chunks(src: str, text: str) -> None:
        # Shared chunk-and-filter step for every ingest stage.
        for chunk in chunk_text(text, max_chars, overlap):
            if len(chunk) >= minlen:
                records.append({"text": chunk, "source": src})

    # GitHub
    try:
        for src, text in ingest_github_sources(cfg):
            _add_chunks(src, text)
    except Exception as e:
        log.warning("GitHub ingest failed: %s", e)

    # Local
    try:
        for src, text in ingest_local_sources(cfg):
            _add_chunks(src, text)
    except Exception as e:
        log.warning("Local ingest failed: %s", e)

    # URLs (optional)
    for url in (cfg.get("urls") or []):
        try:
            r = requests.get(url, timeout=25)
            r.raise_for_status()
            _add_chunks(url, normalize_text(r.text))
        except Exception as e:
            log.warning("URL ingest failed for %s: %s", url, e)

    if dedupe:
        # Key on the chunk text itself: hash() collisions could silently
        # drop distinct chunks, and hash() is salted per process anyway.
        seen: set = set()
        deduped: List[Dict] = []
        for rec in records:
            if rec["text"] in seen:
                continue
            seen.add(rec["text"])
            deduped.append(rec)
        records = deduped

    if not records:
        log.warning("No KB records produced.")
        return 0

    out_path = Path(out_jsonl)
    write_jsonl(records, out_path)
    log.info("Wrote %d chunks to %s", len(records), out_path)
    return len(records)
|
| 287 |
+
|
| 288 |
+
def ensure_kb(out_jsonl: str = "data/kb.jsonl",
              config_path: str = "configs/rag_sources.yaml",
              skip_if_exists: bool = True) -> bool:
    """
    Make sure a knowledge base file exists at *out_jsonl*.

    Returns True immediately when *skip_if_exists* is set and a non-empty
    KB file is already present; otherwise rebuilds it from *config_path*
    and returns True only if at least one record was produced.
    """
    target = Path(out_jsonl)
    already_built = target.exists() and target.stat().st_size > 0
    if skip_if_exists and already_built:
        log.info("KB already present at %s (skipping build)", target)
        return True
    return build_kb_from_config(config_path=config_path, out_jsonl=out_jsonl) > 0
|
app/main.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
|
@@ -9,19 +10,16 @@ from typing import Any, Dict
|
|
| 9 |
from fastapi import FastAPI
|
| 10 |
from fastapi.responses import RedirectResponse
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
from .deps import get_settings
|
| 14 |
-
from .services.chat_service import get_retriever
|
| 15 |
-
|
| 16 |
-
# -----------------------------------------------------------------------------
|
| 17 |
-
# Early: load .env (so HF_TOKEN, ADMIN_TOKEN, etc. are available locally)
|
| 18 |
-
# -----------------------------------------------------------------------------
|
| 19 |
def _load_env_file(paths: list[str]) -> None:
|
| 20 |
-
"""
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
logger = logging.getLogger("uvicorn.error")
|
| 23 |
|
| 24 |
-
# 1) Try python-dotenv
|
| 25 |
try:
|
| 26 |
from dotenv import load_dotenv # type: ignore
|
| 27 |
for p in paths:
|
|
@@ -32,7 +30,7 @@ def _load_env_file(paths: list[str]) -> None:
|
|
| 32 |
logger.info("No .env file found in %s (skipping)", paths)
|
| 33 |
return
|
| 34 |
except Exception:
|
| 35 |
-
# 2) Fallback
|
| 36 |
for p in paths:
|
| 37 |
if not os.path.exists(p):
|
| 38 |
continue
|
|
@@ -53,7 +51,7 @@ def _load_env_file(paths: list[str]) -> None:
|
|
| 53 |
val.startswith("'") and val.endswith("'")
|
| 54 |
):
|
| 55 |
val = val[1:-1]
|
| 56 |
-
# do not clobber existing env (
|
| 57 |
os.environ.setdefault(key, val)
|
| 58 |
logger.info("Loaded environment from %s (fallback parser)", p)
|
| 59 |
return
|
|
@@ -62,13 +60,17 @@ def _load_env_file(paths: list[str]) -> None:
|
|
| 62 |
|
| 63 |
logger.info("No .env loaded (none found / parsers failed)")
|
| 64 |
|
| 65 |
-
# Try
|
| 66 |
_load_env_file([".env", "configs/.env", ".env.local", "configs/.env.local"])
|
| 67 |
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
try:
|
| 73 |
from .middleware import attach_middlewares # singular
|
| 74 |
except Exception:
|
|
@@ -80,11 +82,11 @@ except Exception:
|
|
| 80 |
"attach_middlewares not found; continuing without custom middlewares."
|
| 81 |
)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
# Routers
|
| 85 |
-
# -----------------------------------------------------------------------------
|
| 86 |
from .routers import health, plan, chat
|
| 87 |
|
|
|
|
| 88 |
try:
|
| 89 |
from .ui import router as ui_router # type: ignore
|
| 90 |
HAS_UI = True
|
|
@@ -105,12 +107,27 @@ async def lifespan(app: FastAPI):
|
|
| 105 |
app.state.started_at = time.time()
|
| 106 |
app.state.version = os.getenv("APP_VERSION", "1.0.0")
|
| 107 |
|
| 108 |
-
# --- ADDED: Pre-load the RAG model and index on startup ---
|
| 109 |
logger = logging.getLogger("uvicorn.error")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
logger.info("Warming up RAG retriever...")
|
| 111 |
get_retriever(get_settings())
|
| 112 |
logger.info("RAG retriever is ready.")
|
| 113 |
-
|
|
|
|
| 114 |
hf_token_present = bool(os.getenv("HF_TOKEN"))
|
| 115 |
logger.info(
|
| 116 |
"matrix-ai starting (version=%s, port=%s, hf_token_present=%s)",
|
|
@@ -118,13 +135,12 @@ async def lifespan(app: FastAPI):
|
|
| 118 |
os.getenv("PORT", "7860"),
|
| 119 |
"yes" if hf_token_present else "no",
|
| 120 |
)
|
|
|
|
| 121 |
try:
|
| 122 |
yield
|
| 123 |
finally:
|
| 124 |
uptime = time.time() - getattr(app.state, "started_at", time.time())
|
| 125 |
-
logger.info(
|
| 126 |
-
"matrix-ai shutting down (uptime=%.2fs)", uptime
|
| 127 |
-
)
|
| 128 |
|
| 129 |
|
| 130 |
def create_app() -> FastAPI:
|
|
@@ -138,14 +154,19 @@ def create_app() -> FastAPI:
|
|
| 138 |
lifespan=lifespan,
|
| 139 |
)
|
| 140 |
|
|
|
|
| 141 |
attach_middlewares(app)
|
|
|
|
|
|
|
| 142 |
app.include_router(health.router, tags=["Health"])
|
| 143 |
app.include_router(plan.router, prefix="/v1", tags=["Planning"])
|
| 144 |
app.include_router(chat.router, prefix="/v1", tags=["Chat"])
|
| 145 |
|
|
|
|
| 146 |
if HAS_UI:
|
| 147 |
app.include_router(ui_router, tags=["UI"])
|
| 148 |
else:
|
|
|
|
| 149 |
@app.get("/", include_in_schema=False)
|
| 150 |
async def root() -> Dict[str, Any]:
|
| 151 |
return {
|
|
@@ -155,9 +176,12 @@ def create_app() -> FastAPI:
|
|
| 155 |
"docs": "/docs",
|
| 156 |
"endpoints": {"plan": "/v1/plan", "chat": "/v1/chat", "healthz": "/healthz"},
|
| 157 |
}
|
|
|
|
| 158 |
@app.get("/home", include_in_schema=False)
|
| 159 |
async def home_redirect():
|
| 160 |
return RedirectResponse(url="/docs", status_code=302)
|
|
|
|
| 161 |
return app
|
| 162 |
|
| 163 |
-
|
|
|
|
|
|
| 1 |
+
# app/main.py
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import logging
|
|
|
|
| 10 |
from fastapi import FastAPI
|
| 11 |
from fastapi.responses import RedirectResponse
|
| 12 |
|
| 13 |
+
# ---- Early env load (HF_TOKEN, ADMIN_TOKEN, GITHUB_TOKEN, etc.) ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def _load_env_file(paths: list[str]) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Load environment variables from the first existing path in `paths`.
|
| 17 |
+
Prefer python-dotenv if present; otherwise use a tiny fallback parser.
|
| 18 |
+
Does not override pre-existing env vars (e.g., Space Secrets).
|
| 19 |
+
"""
|
| 20 |
logger = logging.getLogger("uvicorn.error")
|
| 21 |
|
| 22 |
+
# 1) Try python-dotenv
|
| 23 |
try:
|
| 24 |
from dotenv import load_dotenv # type: ignore
|
| 25 |
for p in paths:
|
|
|
|
| 30 |
logger.info("No .env file found in %s (skipping)", paths)
|
| 31 |
return
|
| 32 |
except Exception:
|
| 33 |
+
# 2) Fallback minimal parser
|
| 34 |
for p in paths:
|
| 35 |
if not os.path.exists(p):
|
| 36 |
continue
|
|
|
|
| 51 |
val.startswith("'") and val.endswith("'")
|
| 52 |
):
|
| 53 |
val = val[1:-1]
|
| 54 |
+
# do not clobber existing env (e.g., HF Secrets)
|
| 55 |
os.environ.setdefault(key, val)
|
| 56 |
logger.info("Loaded environment from %s (fallback parser)", p)
|
| 57 |
return
|
|
|
|
| 60 |
|
| 61 |
logger.info("No .env loaded (none found / parsers failed)")
|
| 62 |
|
| 63 |
+
# Try common local locations. HF Spaces will rely on Secrets instead.
|
| 64 |
_load_env_file([".env", "configs/.env", ".env.local", "configs/.env.local"])
|
| 65 |
|
| 66 |
|
| 67 |
+
# ---- RAG bootstrap & warm-up ----
|
| 68 |
+
from .deps import get_settings
|
| 69 |
+
from .services.chat_service import get_retriever
|
| 70 |
+
from .core.rag.build import ensure_kb
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ---- Middlewares ----
|
| 74 |
try:
|
| 75 |
from .middleware import attach_middlewares # singular
|
| 76 |
except Exception:
|
|
|
|
| 82 |
"attach_middlewares not found; continuing without custom middlewares."
|
| 83 |
)
|
| 84 |
|
| 85 |
+
|
| 86 |
+
# ---- Routers ----
|
|
|
|
| 87 |
from .routers import health, plan, chat
|
| 88 |
|
| 89 |
+
# Optional UI bundle (/, /chat, /dev)
|
| 90 |
try:
|
| 91 |
from .ui import router as ui_router # type: ignore
|
| 92 |
HAS_UI = True
|
|
|
|
| 107 |
app.state.started_at = time.time()
|
| 108 |
app.state.version = os.getenv("APP_VERSION", "1.0.0")
|
| 109 |
|
|
|
|
| 110 |
logger = logging.getLogger("uvicorn.error")
|
| 111 |
+
|
| 112 |
+
# 1) Build KB on first boot (skips if already present)
|
| 113 |
+
try:
|
| 114 |
+
if ensure_kb(
|
| 115 |
+
out_jsonl="data/kb.jsonl",
|
| 116 |
+
config_path="configs/rag_sources.yaml",
|
| 117 |
+
skip_if_exists=True,
|
| 118 |
+
):
|
| 119 |
+
logger.info("KB ready at data/kb.jsonl")
|
| 120 |
+
else:
|
| 121 |
+
logger.warning("KB build produced no records; running LLM-only.")
|
| 122 |
+
except Exception as e:
|
| 123 |
+
logger.warning("KB build failed (%s); running LLM-only.", e)
|
| 124 |
+
|
| 125 |
+
# 2) Warm up RAG retriever (indexes data/kb.jsonl if present)
|
| 126 |
logger.info("Warming up RAG retriever...")
|
| 127 |
get_retriever(get_settings())
|
| 128 |
logger.info("RAG retriever is ready.")
|
| 129 |
+
|
| 130 |
+
# 3) Boot log
|
| 131 |
hf_token_present = bool(os.getenv("HF_TOKEN"))
|
| 132 |
logger.info(
|
| 133 |
"matrix-ai starting (version=%s, port=%s, hf_token_present=%s)",
|
|
|
|
| 135 |
os.getenv("PORT", "7860"),
|
| 136 |
"yes" if hf_token_present else "no",
|
| 137 |
)
|
| 138 |
+
|
| 139 |
try:
|
| 140 |
yield
|
| 141 |
finally:
|
| 142 |
uptime = time.time() - getattr(app.state, "started_at", time.time())
|
| 143 |
+
logger.info("matrix-ai shutting down (uptime=%.2fs)", uptime)
|
|
|
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
def create_app() -> FastAPI:
|
|
|
|
| 154 |
lifespan=lifespan,
|
| 155 |
)
|
| 156 |
|
| 157 |
+
# Middlewares (gzip, CORS, rate-limit, req-logs, etc.)
|
| 158 |
attach_middlewares(app)
|
| 159 |
+
|
| 160 |
+
# Core routers
|
| 161 |
app.include_router(health.router, tags=["Health"])
|
| 162 |
app.include_router(plan.router, prefix="/v1", tags=["Planning"])
|
| 163 |
app.include_router(chat.router, prefix="/v1", tags=["Chat"])
|
| 164 |
|
| 165 |
+
# UI (/, /chat, /dev). Your ui.py already defines "/" → /chat
|
| 166 |
if HAS_UI:
|
| 167 |
app.include_router(ui_router, tags=["UI"])
|
| 168 |
else:
|
| 169 |
+
# Minimal root so HF root probes pass even without UI
|
| 170 |
@app.get("/", include_in_schema=False)
|
| 171 |
async def root() -> Dict[str, Any]:
|
| 172 |
return {
|
|
|
|
| 176 |
"docs": "/docs",
|
| 177 |
"endpoints": {"plan": "/v1/plan", "chat": "/v1/chat", "healthz": "/healthz"},
|
| 178 |
}
|
| 179 |
+
|
| 180 |
@app.get("/home", include_in_schema=False)
|
| 181 |
async def home_redirect():
|
| 182 |
return RedirectResponse(url="/docs", status_code=302)
|
| 183 |
+
|
| 184 |
return app
|
| 185 |
|
| 186 |
+
|
| 187 |
+
app = create_app()
|
app/middleware.py
CHANGED
|
@@ -1,30 +1,68 @@
|
|
| 1 |
import time
|
| 2 |
import logging
|
|
|
|
| 3 |
from typing import Callable
|
| 4 |
from fastapi import FastAPI, Request, Response
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
from starlette.middleware.gzip import GZipMiddleware
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from .deps import get_settings
|
| 9 |
from .core.rate_limit import RateLimiter
|
| 10 |
from .core.logging import add_trace_id
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Setup structured logging
|
| 13 |
logger = logging.getLogger("matrix-ai")
|
| 14 |
if not logger.handlers:
|
| 15 |
logger.setLevel(logging.INFO)
|
| 16 |
handler = logging.StreamHandler()
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
handler.setFormatter(formatter)
|
| 21 |
logger.addHandler(handler)
|
| 22 |
|
| 23 |
_rate_limiter = RateLimiter()
|
| 24 |
|
| 25 |
-
def attach_middlewares(app: FastAPI):
|
| 26 |
"""Attaches all required middlewares to the FastAPI app."""
|
|
|
|
|
|
|
| 27 |
app.add_middleware(GZipMiddleware, minimum_size=512)
|
|
|
|
| 28 |
app.add_middleware(
|
| 29 |
CORSMiddleware,
|
| 30 |
allow_origins=["*"],
|
|
@@ -35,20 +73,25 @@ def attach_middlewares(app: FastAPI):
|
|
| 35 |
|
| 36 |
@app.middleware("http")
|
| 37 |
async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
|
|
|
|
| 38 |
add_trace_id(request)
|
|
|
|
| 39 |
settings = get_settings()
|
| 40 |
client_ip = request.client.host if request.client else "unknown"
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
return Response(status_code=429, content="Rate limit exceeded")
|
| 44 |
|
| 45 |
start_time = time.time()
|
| 46 |
response = await call_next(request)
|
| 47 |
-
process_time = (time.time() - start_time) * 1000
|
| 48 |
response.headers["X-Process-Time-Ms"] = f"{process_time:.2f}"
|
| 49 |
|
| 50 |
logger.info(
|
| 51 |
f'"{request.method} {request.url.path}" {response.status_code}',
|
| 52 |
-
extra={
|
| 53 |
)
|
| 54 |
return response
|
|
|
|
| 1 |
import time
|
| 2 |
import logging
|
| 3 |
+
import json
|
| 4 |
from typing import Callable
|
| 5 |
from fastapi import FastAPI, Request, Response
|
| 6 |
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
from starlette.middleware.gzip import GZipMiddleware
|
| 8 |
+
|
| 9 |
+
# Try to import python-json-logger; fall back to a tiny JSON formatter if missing.
|
| 10 |
+
try:
|
| 11 |
+
from pythonjsonlogger import jsonlogger # type: ignore[import-not-found]
|
| 12 |
+
_HAS_PY_JSON_LOGGER = True
|
| 13 |
+
except Exception:
|
| 14 |
+
_HAS_PY_JSON_LOGGER = False
|
| 15 |
+
|
| 16 |
from .deps import get_settings
|
| 17 |
from .core.rate_limit import RateLimiter
|
| 18 |
from .core.logging import add_trace_id
|
| 19 |
|
| 20 |
+
# ---- Fallback JSON formatter (if python-json-logger isn't available) ----
|
| 21 |
+
class _SimpleJsonFormatter(logging.Formatter):
|
| 22 |
+
def format(self, record: logging.LogRecord) -> str:
|
| 23 |
+
payload = {
|
| 24 |
+
"asctime": self.formatTime(record, "%Y-%m-%d %H:%M:%S"),
|
| 25 |
+
"name": record.name,
|
| 26 |
+
"levelname": record.levelname,
|
| 27 |
+
"message": record.getMessage(),
|
| 28 |
+
# We attach trace_id via logger.info(..., extra={"trace_id": "..."}).
|
| 29 |
+
"trace_id": getattr(record, "trace_id", None),
|
| 30 |
+
}
|
| 31 |
+
try:
|
| 32 |
+
return json.dumps(payload, ensure_ascii=False)
|
| 33 |
+
except Exception:
|
| 34 |
+
# Last-ditch plain log if JSON serialization ever fails
|
| 35 |
+
return (
|
| 36 |
+
f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
|
| 37 |
+
f'{payload["message"]} trace_id={payload["trace_id"]}'
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
# Setup structured logging
|
| 41 |
logger = logging.getLogger("matrix-ai")
|
| 42 |
if not logger.handlers:
|
| 43 |
logger.setLevel(logging.INFO)
|
| 44 |
handler = logging.StreamHandler()
|
| 45 |
+
if _HAS_PY_JSON_LOGGER:
|
| 46 |
+
# Same fields you had; python-json-logger builds JSON from this format string
|
| 47 |
+
formatter = jsonlogger.JsonFormatter(
|
| 48 |
+
"%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
|
| 49 |
+
)
|
| 50 |
+
else:
|
| 51 |
+
formatter = _SimpleJsonFormatter()
|
| 52 |
+
logging.getLogger("uvicorn.error").warning(
|
| 53 |
+
"python-json-logger not found; using a minimal JSON formatter."
|
| 54 |
+
)
|
| 55 |
handler.setFormatter(formatter)
|
| 56 |
logger.addHandler(handler)
|
| 57 |
|
| 58 |
_rate_limiter = RateLimiter()
|
| 59 |
|
| 60 |
+
def attach_middlewares(app: FastAPI) -> None:
|
| 61 |
"""Attaches all required middlewares to the FastAPI app."""
|
| 62 |
+
# NOTE: We keep GZip, but your SSE endpoints already set `Content-Encoding: identity`
|
| 63 |
+
# so they won't be buffered/compressed.
|
| 64 |
app.add_middleware(GZipMiddleware, minimum_size=512)
|
| 65 |
+
|
| 66 |
app.add_middleware(
|
| 67 |
CORSMiddleware,
|
| 68 |
allow_origins=["*"],
|
|
|
|
| 73 |
|
| 74 |
@app.middleware("http")
|
| 75 |
async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
|
| 76 |
+
# Attach per-request trace id
|
| 77 |
add_trace_id(request)
|
| 78 |
+
|
| 79 |
settings = get_settings()
|
| 80 |
client_ip = request.client.host if request.client else "unknown"
|
| 81 |
|
| 82 |
+
# Simple fixed-window limiter
|
| 83 |
+
if not _rate_limiter.allow(
|
| 84 |
+
client_ip, request.url.path, settings.limits.rate_per_min
|
| 85 |
+
):
|
| 86 |
return Response(status_code=429, content="Rate limit exceeded")
|
| 87 |
|
| 88 |
start_time = time.time()
|
| 89 |
response = await call_next(request)
|
| 90 |
+
process_time = (time.time() - start_time) * 1000.0
|
| 91 |
response.headers["X-Process-Time-Ms"] = f"{process_time:.2f}"
|
| 92 |
|
| 93 |
logger.info(
|
| 94 |
f'"{request.method} {request.url.path}" {response.status_code}',
|
| 95 |
+
extra={"trace_id": getattr(request.state, "trace_id", "N/A")},
|
| 96 |
)
|
| 97 |
return response
|
app/routers/chat.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
-
from typing import Any,
|
| 5 |
|
| 6 |
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
-
from starlette.concurrency import run_in_threadpool
|
| 9 |
from starlette.responses import StreamingResponse
|
| 10 |
|
| 11 |
from ..deps import get_settings
|
|
@@ -27,12 +27,16 @@ class ChatRequest(BaseModel):
|
|
| 27 |
messages: Optional[List[ChatMessage]] = None
|
| 28 |
|
| 29 |
def as_text(self) -> str:
|
| 30 |
-
if self.query:
|
| 31 |
-
|
| 32 |
-
if self.
|
|
|
|
|
|
|
|
|
|
| 33 |
if self.messages:
|
| 34 |
for m in reversed(self.messages):
|
| 35 |
-
if m.role.lower() == "user":
|
|
|
|
| 36 |
return self.messages[-1].content
|
| 37 |
raise ValueError("Body must include 'query'/'question'/'prompt' or 'messages'")
|
| 38 |
|
|
@@ -50,7 +54,7 @@ async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
|
|
| 50 |
raise HTTPException(status_code=422, detail=str(e))
|
| 51 |
svc = ChatService(settings)
|
| 52 |
try:
|
| 53 |
-
#
|
| 54 |
answer, sources = await run_in_threadpool(svc.answer_with_sources, text)
|
| 55 |
return ChatResponse(answer=answer, sources=sources)
|
| 56 |
except PermissionError as e:
|
|
@@ -63,7 +67,6 @@ async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
|
|
| 63 |
async def chat_get(query: str = Query(...), settings: Settings = Depends(get_settings)):
|
| 64 |
svc = ChatService(settings)
|
| 65 |
try:
|
| 66 |
-
# Run the blocking call in a thread pool
|
| 67 |
answer, sources = await run_in_threadpool(svc.answer_with_sources, query)
|
| 68 |
return ChatResponse(answer=answer, sources=sources)
|
| 69 |
except PermissionError as e:
|
|
@@ -72,7 +75,6 @@ async def chat_get(query: str = Query(...), settings: Settings = Depends(get_set
|
|
| 72 |
raise HTTPException(status_code=502, detail=f"Inference error: {e}")
|
| 73 |
|
| 74 |
|
| 75 |
-
# ---------- Streaming (SSE) ----------
|
| 76 |
def _sse_line(obj: Any) -> str:
|
| 77 |
payload = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
|
| 78 |
return f"data: {payload}\n\n"
|
|
@@ -80,25 +82,29 @@ def _sse_line(obj: Any) -> str:
|
|
| 80 |
|
| 81 |
@router.get("/chat/stream")
|
| 82 |
async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_settings)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
svc = ChatService(settings)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
#
|
| 87 |
yield ":" + (" " * 2048) + "\n\n"
|
|
|
|
| 88 |
yield "event: ping\ndata: 0\n\n"
|
| 89 |
-
|
|
|
|
| 90 |
try:
|
| 91 |
-
|
| 92 |
-
stream_generator = await run_in_threadpool(svc.stream_answer, query)
|
| 93 |
-
any_tokens = False
|
| 94 |
-
for token in stream_generator:
|
| 95 |
if token:
|
| 96 |
any_tokens = True
|
| 97 |
yield _sse_line({"delta": token})
|
| 98 |
-
|
| 99 |
if not any_tokens:
|
| 100 |
yield _sse_line({"delta": ""})
|
| 101 |
yield _sse_line("[DONE]")
|
|
|
|
|
|
|
| 102 |
except Exception as e:
|
| 103 |
yield _sse_line({"error": str(e)})
|
| 104 |
|
|
@@ -108,7 +114,12 @@ async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_
|
|
| 108 |
"Connection": "keep-alive",
|
| 109 |
"Content-Encoding": "identity",
|
| 110 |
}
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
@router.post("/chat/stream")
|
|
@@ -117,4 +128,4 @@ async def chat_stream_post(req: ChatRequest, settings: Settings = Depends(get_se
|
|
| 117 |
q = req.as_text()
|
| 118 |
except ValueError as e:
|
| 119 |
raise HTTPException(status_code=422, detail=str(e))
|
| 120 |
-
return await chat_stream(query=q, settings=settings)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
from typing import Any, Iterator, List, Optional
|
| 5 |
|
| 6 |
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
+
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
|
| 9 |
from starlette.responses import StreamingResponse
|
| 10 |
|
| 11 |
from ..deps import get_settings
|
|
|
|
| 27 |
messages: Optional[List[ChatMessage]] = None
|
| 28 |
|
| 29 |
def as_text(self) -> str:
|
| 30 |
+
if self.query:
|
| 31 |
+
return self.query
|
| 32 |
+
if self.question:
|
| 33 |
+
return self.question
|
| 34 |
+
if self.prompt:
|
| 35 |
+
return self.prompt
|
| 36 |
if self.messages:
|
| 37 |
for m in reversed(self.messages):
|
| 38 |
+
if m.role.lower() == "user":
|
| 39 |
+
return m.content
|
| 40 |
return self.messages[-1].content
|
| 41 |
raise ValueError("Body must include 'query'/'question'/'prompt' or 'messages'")
|
| 42 |
|
|
|
|
| 54 |
raise HTTPException(status_code=422, detail=str(e))
|
| 55 |
svc = ChatService(settings)
|
| 56 |
try:
|
| 57 |
+
# run blocking client in a threadpool
|
| 58 |
answer, sources = await run_in_threadpool(svc.answer_with_sources, text)
|
| 59 |
return ChatResponse(answer=answer, sources=sources)
|
| 60 |
except PermissionError as e:
|
|
|
|
| 67 |
async def chat_get(query: str = Query(...), settings: Settings = Depends(get_settings)):
|
| 68 |
svc = ChatService(settings)
|
| 69 |
try:
|
|
|
|
| 70 |
answer, sources = await run_in_threadpool(svc.answer_with_sources, query)
|
| 71 |
return ChatResponse(answer=answer, sources=sources)
|
| 72 |
except PermissionError as e:
|
|
|
|
| 75 |
raise HTTPException(status_code=502, detail=f"Inference error: {e}")
|
| 76 |
|
| 77 |
|
|
|
|
| 78 |
def _sse_line(obj: Any) -> str:
|
| 79 |
payload = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
|
| 80 |
return f"data: {payload}\n\n"
|
|
|
|
| 82 |
|
| 83 |
@router.get("/chat/stream")
|
| 84 |
async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_settings)):
|
| 85 |
+
"""
|
| 86 |
+
SSE of token deltas. We iterate the sync streaming client in a threadpool
|
| 87 |
+
so the event loop stays free.
|
| 88 |
+
"""
|
| 89 |
svc = ChatService(settings)
|
| 90 |
|
| 91 |
+
def sync_stream() -> Iterator[str]:
|
| 92 |
+
# send anti-buffer padding + ping immediately
|
| 93 |
yield ":" + (" " * 2048) + "\n\n"
|
| 94 |
+
yield "retry: 1500\n\n"
|
| 95 |
yield "event: ping\ndata: 0\n\n"
|
| 96 |
+
|
| 97 |
+
any_tokens = False
|
| 98 |
try:
|
| 99 |
+
for token in svc.stream_answer(query):
|
|
|
|
|
|
|
|
|
|
| 100 |
if token:
|
| 101 |
any_tokens = True
|
| 102 |
yield _sse_line({"delta": token})
|
|
|
|
| 103 |
if not any_tokens:
|
| 104 |
yield _sse_line({"delta": ""})
|
| 105 |
yield _sse_line("[DONE]")
|
| 106 |
+
except GeneratorExit:
|
| 107 |
+
return
|
| 108 |
except Exception as e:
|
| 109 |
yield _sse_line({"error": str(e)})
|
| 110 |
|
|
|
|
| 114 |
"Connection": "keep-alive",
|
| 115 |
"Content-Encoding": "identity",
|
| 116 |
}
|
| 117 |
+
# iterate the sync generator in a threadpool (non-blocking for the loop)
|
| 118 |
+
return StreamingResponse(
|
| 119 |
+
iterate_in_threadpool(sync_stream()),
|
| 120 |
+
media_type="text/event-stream; charset=utf-8",
|
| 121 |
+
headers=headers,
|
| 122 |
+
)
|
| 123 |
|
| 124 |
|
| 125 |
@router.post("/chat/stream")
|
|
|
|
| 128 |
q = req.as_text()
|
| 129 |
except ValueError as e:
|
| 130 |
raise HTTPException(status_code=422, detail=str(e))
|
| 131 |
+
return await chat_stream(query=q, settings=settings)
|
app/services/chat_service.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import List, Tuple
|
| 7 |
|
| 8 |
from ..core.config import Settings
|
| 9 |
from ..core.inference.client import RouterRequestsClient
|
|
@@ -11,33 +14,203 @@ from ..core.rag.retriever import Retriever
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
SYSTEM_PROMPT = (
|
| 15 |
-
"You are MATRIX-AI, a concise, helpful assistant for the Matrix EcoSystem
|
| 16 |
-
"
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
_retriever_instance: Retriever
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
def get_retriever(settings: Settings) -> Retriever
|
| 23 |
-
"""
|
| 24 |
global _retriever_instance
|
| 25 |
if _retriever_instance is not None:
|
| 26 |
return _retriever_instance
|
| 27 |
|
| 28 |
kb_path = os.getenv("RAG_KB_PATH", "data/kb.jsonl")
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
_retriever_instance = Retriever(kb_path=kb_path, top_k=settings.rag.top_k)
|
| 32 |
logger.info("RAG enabled with KB at %s (top_k=%d)", kb_path, settings.rag.top_k)
|
| 33 |
-
|
| 34 |
-
logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
except Exception as e:
|
| 36 |
-
logger.warning("
|
|
|
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
class ChatService:
|
| 42 |
def __init__(self, settings: Settings):
|
| 43 |
self.settings = settings
|
|
@@ -46,51 +219,80 @@ class ChatService:
|
|
| 46 |
fallback=settings.model.fallback,
|
| 47 |
provider=getattr(settings.model, "provider", None),
|
| 48 |
max_retries=2,
|
|
|
|
|
|
|
| 49 |
)
|
| 50 |
-
#
|
| 51 |
self.retriever = get_retriever(settings)
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
if not self.retriever:
|
| 55 |
return "", []
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
return "", []
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
sources =
|
| 62 |
-
return
|
| 63 |
|
| 64 |
def _augment(self, query: str) -> Tuple[str, List[str]]:
|
| 65 |
"""
|
| 66 |
Build the final user message (with optional CONTEXT) and return sources.
|
| 67 |
"""
|
| 68 |
-
ctx, sources = self.
|
| 69 |
-
|
| 70 |
-
# --- THIS IS THE CORRECTED PROMPT ---
|
| 71 |
if ctx:
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
else:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
return augmented, sources
|
| 79 |
|
| 80 |
-
#
|
| 81 |
def answer_with_sources(self, query: str) -> Tuple[str, List[str]]:
|
| 82 |
user_msg, sources = self._augment(query)
|
| 83 |
text = self.client.chat_nonstream(
|
| 84 |
-
SYSTEM_PROMPT,
|
|
|
|
| 85 |
max_tokens=self.settings.model.max_new_tokens,
|
| 86 |
temperature=self.settings.model.temperature,
|
| 87 |
)
|
| 88 |
return text, sources
|
| 89 |
|
|
|
|
| 90 |
def stream_answer(self, query: str):
|
| 91 |
user_msg, _ = self._augment(query)
|
|
|
|
| 92 |
return self.client.chat_stream(
|
| 93 |
-
SYSTEM_PROMPT,
|
|
|
|
| 94 |
max_tokens=self.settings.model.max_new_tokens,
|
| 95 |
temperature=self.settings.model.temperature,
|
| 96 |
-
)
|
|
|
|
| 1 |
+
# app/services/chat_service.py
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
+
import re
|
| 7 |
+
import threading
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import List, Tuple, Dict, Optional
|
| 10 |
|
| 11 |
from ..core.config import Settings
|
| 12 |
from ..core.inference.client import RouterRequestsClient
|
|
|
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
+
# --- Optional cross-encoder reranker (graceful fallback) ---
|
| 18 |
+
try:
|
| 19 |
+
from sentence_transformers import CrossEncoder # type: ignore
|
| 20 |
+
except Exception: # pragma: no cover
|
| 21 |
+
CrossEncoder = None # type: ignore
|
| 22 |
+
|
| 23 |
SYSTEM_PROMPT = (
|
| 24 |
+
"You are MATRIX-AI, a concise, helpful assistant for the Matrix EcoSystem.\n"
|
| 25 |
+
"You MUST use the provided CONTEXT. If an answer is not supported by the context, say you don't know.\n"
|
| 26 |
+
"Prefer short, clear sentences. Include product/feature names exactly as written in the context.\n"
|
| 27 |
)
|
| 28 |
|
| 29 |
+
# Thread-safe singleton retriever
|
| 30 |
+
_retriever_instance: Optional[Retriever] = None
|
| 31 |
+
_retriever_lock = threading.Lock()
|
| 32 |
+
|
| 33 |
|
| 34 |
+
def get_retriever(settings: Settings) -> Optional[Retriever]:
|
| 35 |
+
"""Initialize and return a single Retriever instance (double-checked locking)."""
|
| 36 |
global _retriever_instance
|
| 37 |
if _retriever_instance is not None:
|
| 38 |
return _retriever_instance
|
| 39 |
|
| 40 |
kb_path = os.getenv("RAG_KB_PATH", "data/kb.jsonl")
|
| 41 |
+
if not Path(kb_path).exists():
|
| 42 |
+
logger.info("RAG KB not found at %s — running LLM-only.", kb_path)
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
with _retriever_lock:
|
| 46 |
+
if _retriever_instance is not None:
|
| 47 |
+
return _retriever_instance
|
| 48 |
+
try:
|
| 49 |
_retriever_instance = Retriever(kb_path=kb_path, top_k=settings.rag.top_k)
|
| 50 |
logger.info("RAG enabled with KB at %s (top_k=%d)", kb_path, settings.rag.top_k)
|
| 51 |
+
except Exception as e:
|
| 52 |
+
logger.warning("RAG disabled (failed to initialize Retriever: %s)", e)
|
| 53 |
+
_retriever_instance = None
|
| 54 |
+
return _retriever_instance
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ----------------------------
|
| 58 |
+
# RAG utilities (ranking & snippets)
|
| 59 |
+
# ----------------------------
|
| 60 |
+
|
| 61 |
+
_ALIAS_TABLE: Dict[str, List[str]] = {
|
| 62 |
+
# canonical -> aliases
|
| 63 |
+
"matrixhub": ["matrix hub", "matrixhub", "hub api", "catalog", "registry", "cas"],
|
| 64 |
+
"mcp": ["model context protocol", "mcp", "manifest", "server manifest", "admin api"],
|
| 65 |
+
"agent-matrix": ["agent-matrix", "matrix agents", "matrix ecosystem", "matrix toolkit"],
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _normalize(text: str) -> List[str]:
|
| 72 |
+
return [t.lower() for t in _WORD_RE.findall(text)]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _expand_query(q: str) -> str:
|
| 76 |
+
"""Add domain aliases to help the embedding retrieve the right docs."""
|
| 77 |
+
ql = q.lower()
|
| 78 |
+
extras: List[str] = []
|
| 79 |
+
for canon, variants in _ALIAS_TABLE.items():
|
| 80 |
+
if any(v in ql for v in variants):
|
| 81 |
+
extras.extend([canon] + variants)
|
| 82 |
+
if extras:
|
| 83 |
+
return q + " | " + " ".join(sorted(set(extras)))
|
| 84 |
+
return q
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _keyword_overlap_score(query: str, text: str) -> float:
|
| 88 |
+
"""Simple lexical grounding score: Jaccard over unique tokens (stopword-agnostic)."""
|
| 89 |
+
q_tokens = set(_normalize(query))
|
| 90 |
+
d_tokens = set(_normalize(text))
|
| 91 |
+
if not q_tokens or not d_tokens:
|
| 92 |
+
return 0.0
|
| 93 |
+
inter = len(q_tokens & d_tokens)
|
| 94 |
+
union = len(q_tokens | d_tokens)
|
| 95 |
+
return inter / max(1, union)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _domain_boost(text: str) -> float:
|
| 99 |
+
t = text.lower()
|
| 100 |
+
boost = 0.0
|
| 101 |
+
for term in ("matrixhub", "hub api", "catalog", "mcp", "server manifest", "cas"):
|
| 102 |
+
if term in t:
|
| 103 |
+
boost += 0.05
|
| 104 |
+
return min(boost, 0.25)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
|
| 108 |
+
"""
|
| 109 |
+
Split by blank lines and pick 1-2 best paragraphs by lexical overlap.
|
| 110 |
+
Keep it compact to avoid swamping the LLM.
|
| 111 |
+
"""
|
| 112 |
+
paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
|
| 113 |
+
if not paras:
|
| 114 |
+
return text[:max_chars]
|
| 115 |
+
|
| 116 |
+
scored = [(p, _keyword_overlap_score(query, p)) for p in paras]
|
| 117 |
+
scored.sort(key=lambda x: x[1], reverse=True)
|
| 118 |
+
picked: List[str] = []
|
| 119 |
+
used = 0
|
| 120 |
+
for p, _s in scored[:4]:
|
| 121 |
+
if used >= max_chars:
|
| 122 |
+
break
|
| 123 |
+
picked.append(p)
|
| 124 |
+
used += len(p) + 2
|
| 125 |
+
if used >= max_chars or len(picked) >= 2:
|
| 126 |
+
break
|
| 127 |
+
return "\n".join(picked)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _cross_encoder_scores(
|
| 131 |
+
model: Optional["CrossEncoder"],
|
| 132 |
+
query: str,
|
| 133 |
+
docs: List[Dict],
|
| 134 |
+
max_pairs: int = 50,
|
| 135 |
+
) -> Optional[List[float]]:
|
| 136 |
+
if not model:
|
| 137 |
+
return None
|
| 138 |
+
try:
|
| 139 |
+
pairs = [(query, d["text"][:1200]) for d in docs[:max_pairs]]
|
| 140 |
+
return list(model.predict(pairs))
|
| 141 |
except Exception as e:
|
| 142 |
+
logger.warning("Cross-encoder scoring failed; continuing without it (%s)", e)
|
| 143 |
+
return None
|
| 144 |
|
| 145 |
+
|
| 146 |
+
def _rerank_docs(
|
| 147 |
+
docs: List[Dict],
|
| 148 |
+
query: str,
|
| 149 |
+
k_final: int,
|
| 150 |
+
reranker: Optional["CrossEncoder"] = None,
|
| 151 |
+
) -> List[Dict]:
|
| 152 |
+
"""
|
| 153 |
+
Combine vector score, keyword overlap, domain boost, and optional cross-encoder.
|
| 154 |
+
Score = 0.55*vec + 0.35*lex + 0.10*boost (+ 0.20*ce if available, rescaled).
|
| 155 |
+
"""
|
| 156 |
+
if not docs:
|
| 157 |
+
return []
|
| 158 |
+
|
| 159 |
+
# Normalize vector scores (cosine similarities 0..1-ish) to 0..1
|
| 160 |
+
vec_scores = [float(d.get("score", 0.0)) for d in docs]
|
| 161 |
+
if vec_scores:
|
| 162 |
+
vmin = min(vec_scores)
|
| 163 |
+
vmax = max(vec_scores)
|
| 164 |
+
rng = max(1e-6, (vmax - vmin))
|
| 165 |
+
vec_norm = [(v - vmin) / rng for v in vec_scores]
|
| 166 |
+
else:
|
| 167 |
+
vec_norm = [0.0] * len(docs)
|
| 168 |
+
|
| 169 |
+
lex_scores = [_keyword_overlap_score(query, d["text"]) for d in docs]
|
| 170 |
+
boosts = [_domain_boost(d["text"]) for d in docs]
|
| 171 |
+
|
| 172 |
+
ce_scores = _cross_encoder_scores(reranker, query, docs)
|
| 173 |
+
if ce_scores:
|
| 174 |
+
# Min-max normalize CE too
|
| 175 |
+
cmin, cmax = min(ce_scores), max(ce_scores)
|
| 176 |
+
crng = max(1e-6, (cmax - cmin))
|
| 177 |
+
ce_norm = [(c - cmin) / crng for c in ce_scores]
|
| 178 |
+
else:
|
| 179 |
+
ce_norm = None
|
| 180 |
+
|
| 181 |
+
merged: List[Tuple[float, Dict]] = []
|
| 182 |
+
for i, d in enumerate(docs):
|
| 183 |
+
score = 0.55 * vec_norm[i] + 0.35 * lex_scores[i] + 0.10 * boosts[i]
|
| 184 |
+
if ce_norm is not None:
|
| 185 |
+
score = 0.80 * score + 0.20 * ce_norm[i]
|
| 186 |
+
merged.append((score, d))
|
| 187 |
+
|
| 188 |
+
merged.sort(key=lambda x: x[0], reverse=True)
|
| 189 |
+
top = [d for _s, d in merged[:k_final]]
|
| 190 |
+
return top
|
| 191 |
|
| 192 |
|
| 193 |
+
def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4) -> Tuple[str, List[str]]:
|
| 194 |
+
"""
|
| 195 |
+
Build a compact CONTEXT section using best paragraphs from top docs.
|
| 196 |
+
Return (context_text, sources).
|
| 197 |
+
"""
|
| 198 |
+
blocks: List[str] = []
|
| 199 |
+
sources: List[str] = []
|
| 200 |
+
for i, d in enumerate(docs[:max_blocks]):
|
| 201 |
+
snip = _best_paragraphs(d["text"], query, max_chars=700)
|
| 202 |
+
src = d.get("source", f"kb:{i}")
|
| 203 |
+
blocks.append(f"[{i+1}] {snip}\n(source: {src})")
|
| 204 |
+
sources.append(src)
|
| 205 |
+
if not blocks:
|
| 206 |
+
return "", []
|
| 207 |
+
prelude = "CONTEXT (use only these facts; if missing, say you don't know):"
|
| 208 |
+
return prelude + "\n\n" + "\n\n".join(blocks), sources
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ----------------------------
|
| 212 |
+
# Service
|
| 213 |
+
# ----------------------------
|
| 214 |
class ChatService:
|
| 215 |
def __init__(self, settings: Settings):
|
| 216 |
self.settings = settings
|
|
|
|
| 219 |
fallback=settings.model.fallback,
|
| 220 |
provider=getattr(settings.model, "provider", None),
|
| 221 |
max_retries=2,
|
| 222 |
+
connect_timeout=10.0,
|
| 223 |
+
read_timeout=60.0,
|
| 224 |
)
|
| 225 |
+
# RAG (singleton)
|
| 226 |
self.retriever = get_retriever(settings)
|
| 227 |
|
| 228 |
+
# Optional cross-encoder (large; disable via env RAG_RERANK=false)
|
| 229 |
+
self.reranker = None
|
| 230 |
+
use_rerank = os.getenv("RAG_RERANK", "true").lower() in ("1", "true", "yes")
|
| 231 |
+
if use_rerank and CrossEncoder is not None:
|
| 232 |
+
try:
|
| 233 |
+
self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2")
|
| 234 |
+
logger.info("RAG cross-encoder reranker enabled.")
|
| 235 |
+
except Exception as e:
|
| 236 |
+
logger.warning("Reranker disabled: %s", e)
|
| 237 |
+
|
| 238 |
+
# ---------- RAG core ----------
|
| 239 |
+
def _retrieve_best(self, query: str) -> Tuple[str, List[str]]:
|
| 240 |
+
"""
|
| 241 |
+
Retrieve many, rerank, and build a compact, high-signal CONTEXT.
|
| 242 |
+
Returns (context_text, sources).
|
| 243 |
+
"""
|
| 244 |
if not self.retriever:
|
| 245 |
return "", []
|
| 246 |
+
|
| 247 |
+
expanded = _expand_query(query)
|
| 248 |
+
# Retrieve a wider candidate pool, then rerank.
|
| 249 |
+
k_base = max(4, int(self.settings.rag.top_k) * 5)
|
| 250 |
+
try:
|
| 251 |
+
cands = self.retriever.retrieve(expanded, k=k_base) # [{'text','source','score'}]
|
| 252 |
+
except Exception as e:
|
| 253 |
+
logger.warning("Retriever failed (%s); falling back to LLM-only.", e)
|
| 254 |
+
return "", []
|
| 255 |
+
|
| 256 |
+
if not cands:
|
| 257 |
return "", []
|
| 258 |
+
|
| 259 |
+
top = _rerank_docs(cands, query, k_final=max(3, self.settings.rag.top_k), reranker=self.reranker)
|
| 260 |
+
ctx, sources = _build_context_from_docs(top, query, max_blocks=max(3, self.settings.rag.top_k))
|
| 261 |
+
return ctx, sources
|
| 262 |
|
| 263 |
def _augment(self, query: str) -> Tuple[str, List[str]]:
|
| 264 |
"""
|
| 265 |
Build the final user message (with optional CONTEXT) and return sources.
|
| 266 |
"""
|
| 267 |
+
ctx, sources = self._retrieve_best(query)
|
|
|
|
|
|
|
| 268 |
if ctx:
|
| 269 |
+
user_msg = (
|
| 270 |
+
f"{ctx}\n\n"
|
| 271 |
+
"Based only on the context above, answer the question succinctly.\n"
|
| 272 |
+
f"Question: {query}\nAnswer:"
|
| 273 |
+
)
|
| 274 |
else:
|
| 275 |
+
user_msg = query # LLM-only fallback
|
| 276 |
+
return user_msg, sources
|
|
|
|
|
|
|
| 277 |
|
| 278 |
+
# ---------- Non-stream ----------
|
| 279 |
def answer_with_sources(self, query: str) -> Tuple[str, List[str]]:
|
| 280 |
user_msg, sources = self._augment(query)
|
| 281 |
text = self.client.chat_nonstream(
|
| 282 |
+
SYSTEM_PROMPT,
|
| 283 |
+
user_msg,
|
| 284 |
max_tokens=self.settings.model.max_new_tokens,
|
| 285 |
temperature=self.settings.model.temperature,
|
| 286 |
)
|
| 287 |
return text, sources
|
| 288 |
|
| 289 |
+
# ---------- Stream ----------
|
| 290 |
def stream_answer(self, query: str):
|
| 291 |
user_msg, _ = self._augment(query)
|
| 292 |
+
# SYNC generator yielding token deltas; router wraps in SSE
|
| 293 |
return self.client.chat_stream(
|
| 294 |
+
SYSTEM_PROMPT,
|
| 295 |
+
user_msg,
|
| 296 |
max_tokens=self.settings.model.max_new_tokens,
|
| 297 |
temperature=self.settings.model.temperature,
|
| 298 |
+
)
|
configs/rag_sources.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Where to pull documentation from when building the RAG knowledge base.
|
| 2 |
+
# You can add/remove repos here; the builder will respect these sources.
|
| 3 |
+
|
| 4 |
+
github:
|
| 5 |
+
# 1) Explicit repos (stable)
|
| 6 |
+
repos:
|
| 7 |
+
- owner: agent-matrix
|
| 8 |
+
name: matrix-cli
|
| 9 |
+
branch: master
|
| 10 |
+
docs_paths: ["docs"] # folders to harvest (recursive)
|
| 11 |
+
include_readme: true
|
| 12 |
+
- owner: agent-matrix
|
| 13 |
+
name: matrix-python-sdk
|
| 14 |
+
branch: master
|
| 15 |
+
docs_paths: ["docs"]
|
| 16 |
+
include_readme: true
|
| 17 |
+
- owner: agent-matrix
|
| 18 |
+
name: matrixlink
|
| 19 |
+
branch: master
|
| 20 |
+
docs_paths: ["docs"]
|
| 21 |
+
include_readme: true
|
| 22 |
+
- owner: agent-matrix
|
| 23 |
+
name: matrix-hub
|
| 24 |
+
branch: master
|
| 25 |
+
docs_paths: ["docs"]
|
| 26 |
+
include_readme: true
|
| 27 |
+
|
| 28 |
+
# 2) Optionally scan an entire org for repos (README + docs/ if present)
|
| 29 |
+
# Comment out if you want only the explicit list above.
|
| 30 |
+
orgs:
|
| 31 |
+
- agent-matrix
|
| 32 |
+
|
| 33 |
+
# Local content in THIS repo (optional but recommended)
|
| 34 |
+
local:
|
| 35 |
+
paths:
|
| 36 |
+
- docs # everything under /docs
|
| 37 |
+
- README.md # root readme
|
| 38 |
+
glob: "**/*.md" # or "**/*.{md,mdx,txt}"
|
| 39 |
+
|
| 40 |
+
# Extra public URLs to pull (optional)
|
| 41 |
+
urls: []
|
data/kb.jsonl
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
CHANGED
|
@@ -18,3 +18,8 @@ pytest
|
|
| 18 |
ruff
|
| 19 |
mypy
|
| 20 |
pytest-asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
ruff
|
| 19 |
mypy
|
| 20 |
pytest-asyncio
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
requests>=2.32.0
|
| 24 |
+
beautifulsoup4>=4.12.3 # only used if you later add generic HTML URLs
|
| 25 |
+
PyYAML>=6.0.1
|
scripts/build_kb.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Builds/refreshes the local RAG KB (data/kb.jsonl) from GitHub + local docs.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
|
| 7 |
+
python scripts/build_kb.py --config ... --out ... --force
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
# --- Ensure THIS repo is first on sys.path (avoid clashing 'app' packages) ---
|
| 18 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 19 |
+
sys.path.insert(0, str(ROOT))
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("build_kb")
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
| 23 |
+
|
| 24 |
+
# Import the builder from this project
|
| 25 |
+
try:
|
| 26 |
+
from app.core.rag.build import build_kb_from_config, ensure_kb # type: ignore
|
| 27 |
+
except Exception as e: # pragma: no cover
|
| 28 |
+
logger.error("Failed importing KB builder from app.core.rag.build: %s", e)
|
| 29 |
+
logger.error("Make sure you're running from the project root and PYTHONPATH includes '.'.")
|
| 30 |
+
sys.exit(2)
|
| 31 |
+
|
| 32 |
+
def main() -> int:
|
| 33 |
+
p = argparse.ArgumentParser()
|
| 34 |
+
p.add_argument("--config", required=True, help="Path to configs/rag_sources.yaml")
|
| 35 |
+
p.add_argument("--out", required=True, help="Output JSONL file, e.g., data/kb.jsonl")
|
| 36 |
+
p.add_argument("--force", action="store_true", help="Delete output file first, then rebuild")
|
| 37 |
+
args = p.parse_args()
|
| 38 |
+
|
| 39 |
+
out_path = Path(args.out)
|
| 40 |
+
if args.force and out_path.exists():
|
| 41 |
+
logger.info("Removing existing %s", out_path)
|
| 42 |
+
out_path.unlink()
|
| 43 |
+
|
| 44 |
+
# If you want a one-liner that skips if exists, use ensure_kb:
|
| 45 |
+
# created = ensure_kb(out_jsonl=args.out, config_path=args.config, skip_if_exists=True)
|
| 46 |
+
# logger.info("KB %s at %s", "ready" if created else "unchanged", args.out)
|
| 47 |
+
|
| 48 |
+
# Otherwise, always (re)build:
|
| 49 |
+
n = build_kb_from_config(config_path=args.config, out_jsonl=args.out)
|
| 50 |
+
logger.info("Wrote %d records to %s", n, args.out)
|
| 51 |
+
return 0
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
raise SystemExit(main())
|