ruslanmv committed
Commit 215df55 · Parent: 6338f31
.gitignore CHANGED
@@ -32,3 +32,5 @@ Thumbs.db
 # RAG index files
 .faiss/
 /backup
+copy *.*
+* copy.*
Makefile CHANGED
@@ -55,6 +55,10 @@ help:
 	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run" "Run uvicorn (PORT=$(PORT))"
 	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "run-hot" "Run with --reload"
 	@echo
+	@echo "$(BRIGHT_GREEN)RAG / Knowledge Base$(RESET)"
+	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "kb" "Build/refresh KB from GitHub + local docs (writes data/kb.jsonl)"
+	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "kb-force" "Force rebuild KB (deletes existing data/kb.jsonl)"
+	@echo
 	@echo "$(BRIGHT_GREEN)Docker$(RESET)"
 	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-build" "Build local image ($(IMG_NAME))"
 	@printf " $(BRIGHT_GREEN)%-22s$(RESET) $(DIM_GREEN)%s$(RESET)\n" "docker-run" "Run local container (maps $(PORT))"
@@ -100,6 +104,15 @@ run: install
 run-hot: install
 	@PORT=$(PORT) $(VENV_DIR)/bin/uvicorn $(APP_MODULE) --host 0.0.0.0 --port $(PORT) --reload
 
+# ---------------------------------------------------------------------------
+# RAG / Knowledge Base
+# ---------------------------------------------------------------------------
+kb: install
+	@PYTHONPATH=. $(PYTHON) scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
+
+kb-force: install
+	@rm -f data/kb.jsonl && PYTHONPATH=. $(PYTHON) scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
+
 # ---------------------------------------------------------------------------
 # Docker
 # ---------------------------------------------------------------------------
@@ -121,4 +134,4 @@ space-url:
 clean:
 	@rm -rf .venv __pycache__ .pytest_cache .ruff_cache .mypy_cache dist build *.egg-info
 
-.PHONY: help venv install lint fmt test run run-hot docker-build docker-run space-url clean
+.PHONY: help venv install lint fmt test run run-hot kb kb-force docker-build docker-run space-url clean
app/core/rag/build.py ADDED
@@ -0,0 +1,300 @@
+from __future__ import annotations
+import json, os, re, time, math, logging
+from pathlib import Path
+from typing import Dict, List, Iterable, Tuple, Optional
+
+import yaml
+import requests
+
+log = logging.getLogger(__name__)
+
+# -------------------------
+# Text cleaning & chunking
+# -------------------------
+
+_MD_FRONTMATTER = re.compile(r"^---\s*\n.*?\n---\s*\n", re.DOTALL)
+
+def normalize_text(text: str) -> str:
+    lines = [ln.strip() for ln in text.splitlines()]
+    cleaned = []
+    for ln in lines:
+        if not ln:
+            continue
+        if sum(ch.isalnum() for ch in ln) < 3:
+            continue
+        cleaned.append(ln)
+    s = "\n".join(cleaned)
+    s = re.sub(r"\n{3,}", "\n\n", s)
+    return s.strip()
+
+def md_to_text(md: str) -> str:
+    md = re.sub(_MD_FRONTMATTER, "", md)
+    md = re.sub(r"```.*?```", "", md, flags=re.DOTALL)  # drop fenced code
+    md = re.sub(r"!\[[^\]]*\]\([^)]+\)", "", md)  # drop images
+    md = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md)  # links -> label
+    md = re.sub(r"^\s{0,3}#{1,6}\s*", "", md, flags=re.MULTILINE)
+    md = md.replace("`", "")
+    md = re.sub(r"^\s*[-*+]\s+", "• ", md, flags=re.MULTILINE)
+    md = re.sub(r"^\s*>\s?", "", md, flags=re.MULTILINE)
+    return normalize_text(md)
+
+def chunk_text(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
+    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
+    out: List[str] = []
+    buf = ""
+    for p in paras:
+        if len(p) > max_chars:
+            i = 0
+            while i < len(p):
+                j = min(i + max_chars, len(p))
+                out.append(p[i:j])
+                i = j - overlap if j - overlap > i else j
+            continue
+        if len(buf) + 2 + len(p) <= max_chars:
+            buf = (buf + "\n\n" + p) if buf else p
+        else:
+            if buf:
+                out.append(buf)
+            buf = p
+    if buf:
+        out.append(buf)
+    return out
+
+def write_jsonl(records: Iterable[Dict], out_path: Path) -> None:
+    out_path.parent.mkdir(parents=True, exist_ok=True)
+    with out_path.open("w", encoding="utf-8") as f:
+        for rec in records:
+            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
+
+# -------------------------
+# GitHub API helpers
+# -------------------------
+
+def gh_session() -> requests.Session:
+    s = requests.Session()
+    s.headers.update({
+        "Accept": "application/vnd.github+json",
+        "User-Agent": "matrix-ai-rag-builder/1.0",
+    })
+    tok = os.getenv("GITHUB_TOKEN")
+    if tok:
+        s.headers["Authorization"] = f"Bearer {tok}"
+    return s
+
+def gh_get_json(url: str, sess: requests.Session, max_retries: int = 3) -> Dict | List:
+    backoff = 1.0
+    for attempt in range(max_retries):
+        r = sess.get(url, timeout=25)
+        if r.status_code == 403 and "rate limit" in r.text.lower():
+            log.warning("GitHub rate-limited; sleeping %.1fs", backoff)
+            time.sleep(backoff)
+            backoff = min(backoff * 2, 30)
+            continue
+        r.raise_for_status()
+        return r.json()
+    r.raise_for_status()
+    return {}
+
+def gh_list_org_repos(org: str, sess: requests.Session) -> List[Dict]:
+    repos: List[Dict] = []
+    page = 1
+    while True:
+        url = f"https://api.github.com/orgs/{org}/repos?per_page=100&page={page}"
+        js = gh_get_json(url, sess)
+        if not js:
+            break
+        repos.extend(js)
+        if len(js) < 100:
+            break
+        page += 1
+    return repos
+
+def gh_list_tree(owner: str, repo: str, branch: str, sess: requests.Session) -> List[Dict]:
+    url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
+    js = gh_get_json(url, sess)
+    return js.get("tree", []) if isinstance(js, dict) else []
+
+def gh_fetch_raw(owner: str, repo: str, branch: str, path: str, sess: requests.Session) -> Optional[str]:
+    raw_url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}"
+    r = sess.get(raw_url, timeout=25)
+    if r.status_code == 404 and branch == "main":  # try master fallback
+        raw_url = f"https://raw.githubusercontent.com/{owner}/{repo}/master/{path}"
+        r = sess.get(raw_url, timeout=25)
+    if r.status_code == 200:
+        return r.text
+    return None
+
+# -------------------------
+# Builders
+# -------------------------
+
+def ingest_github_repo(owner: str, name: str, branch: str, docs_paths: List[str],
+                       include_readme: bool, exts: Tuple[str,...] = (".md",".mdx",".txt")) -> List[Tuple[str,str]]:
+    sess = gh_session()
+    out: List[Tuple[str,str]] = []
+
+    # README
+    if include_readme:
+        for candidate in ("README.md", "readme.md", "README.MD"):
+            t = gh_fetch_raw(owner, name, branch, candidate, sess)
+            if t:
+                out.append((f"github:{owner}/{name}/{candidate}", md_to_text(t)))
+                break
+
+    # Tree -> docs paths
+    tree = gh_list_tree(owner, name, branch, sess)
+    if not tree:
+        return out
+
+    wanted_dirs = [p.strip("/").lower() for p in docs_paths]
+    for entry in tree:
+        if entry.get("type") != "blob":
+            continue
+        path = entry.get("path", "")
+        lower = path.lower()
+        if not lower.endswith(exts):
+            continue
+        if any(lower.startswith(d + "/") for d in wanted_dirs):
+            t = gh_fetch_raw(owner, name, branch, path, sess)
+            if not t:
+                continue
+            txt = md_to_text(t) if lower.endswith((".md",".mdx")) else normalize_text(t)
+            if txt:
+                out.append((f"github:{owner}/{name}/{path}", txt))
+    return out
+
+def ingest_github_sources(cfg: Dict) -> List[Tuple[str,str]]:
+    out: List[Tuple[str,str]] = []
+    gh = cfg.get("github") or {}
+    sess = gh_session()
+
+    # explicit repos
+    for repo in (gh.get("repos") or []):
+        owner = repo["owner"]
+        name = repo["name"]
+        branch = repo.get("branch", "main")
+        docs_paths = repo.get("docs_paths", ["docs"])
+        include_readme = bool(repo.get("include_readme", True))
+        out.extend(ingest_github_repo(owner, name, branch, docs_paths, include_readme))
+
+    # whole org scan (README + docs/)
+    for org in (gh.get("orgs") or []):
+        try:
+            repos = gh_list_org_repos(org, sess)
+        except Exception as e:
+            log.warning("Failed to list org %s: %s", org, e)
+            continue
+        for r in repos:
+            owner = r["owner"]["login"]
+            name = r["name"]
+            default_branch = r.get("default_branch", "main")
+            # README + docs/
+            out.extend(ingest_github_repo(owner, name, default_branch, ["docs"], include_readme=True))
+    return out
+
+def ingest_local_sources(cfg: Dict) -> List[Tuple[str,str]]:
+    out: List[Tuple[str,str]] = []
+    local = cfg.get("local") or {}
+    paths = local.get("paths") or []
+    glob_pat = local.get("glob", "**/*.md")
+    for p in paths:
+        fp = Path(p)
+        if fp.is_file():
+            try:
+                raw = fp.read_text(encoding="utf-8", errors="ignore")
+                txt = md_to_text(raw) if fp.suffix.lower() in {".md",".mdx"} else normalize_text(raw)
+                if txt:
+                    out.append((str(fp), txt))
+            except Exception as e:
+                log.warning("Failed reading %s: %s", fp, e)
+        elif fp.is_dir():
+            for f in fp.rglob(glob_pat):
+                try:
+                    raw = f.read_text(encoding="utf-8", errors="ignore")
+                    txt = md_to_text(raw) if f.suffix.lower() in {".md",".mdx"} else normalize_text(raw)
+                    if txt:
+                        out.append((str(f), txt))
+                except Exception as e:
+                    log.warning("Failed reading %s: %s", f, e)
+    return out
+
+def build_kb_from_config(config_path: str = "configs/rag_sources.yaml",
+                         out_jsonl: str = "data/kb.jsonl",
+                         max_chars: int = 800,
+                         overlap: int = 120,
+                         minlen: int = 200,
+                         dedupe: bool = True) -> int:
+    cfg: Dict = {}
+    p = Path(config_path)
+    if p.exists():
+        cfg = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
+    else:
+        log.warning("rag_sources.yaml not found at %s (using defaults)", p)
+
+    records: List[Dict] = []
+
+    # GitHub
+    try:
+        gh_docs = ingest_github_sources(cfg)
+        for src, text in gh_docs:
+            for chunk in chunk_text(text, max_chars, overlap):
+                if len(chunk) >= minlen:
+                    records.append({"text": chunk, "source": src})
+    except Exception as e:
+        log.warning("GitHub ingest failed: %s", e)
+
+    # Local
+    try:
+        loc_docs = ingest_local_sources(cfg)
+        for src, text in loc_docs:
+            for chunk in chunk_text(text, max_chars, overlap):
+                if len(chunk) >= minlen:
+                    records.append({"text": chunk, "source": src})
+    except Exception as e:
+        log.warning("Local ingest failed: %s", e)
+
+    # URLs (optional)
+    for url in (cfg.get("urls") or []):
+        try:
+            r = requests.get(url, timeout=25)
+            r.raise_for_status()
+            txt = normalize_text(r.text)
+            for chunk in chunk_text(txt, max_chars, overlap):
+                if len(chunk) >= minlen:
+                    records.append({"text": chunk, "source": url})
+        except Exception as e:
+            log.warning("URL ingest failed for %s: %s", url, e)
+
+    if dedupe:
+        seen = set()
+        deduped: List[Dict] = []
+        for rec in records:
+            h = hash(rec["text"])
+            if h in seen:
+                continue
+            seen.add(h)
+            deduped.append(rec)
+        records = deduped
+
+    if not records:
+        log.warning("No KB records produced.")
+        return 0
+
+    out_path = Path(out_jsonl)
+    write_jsonl(records, out_path)
+    log.info("Wrote %d chunks to %s", len(records), out_path)
+    return len(records)
+
+def ensure_kb(out_jsonl: str = "data/kb.jsonl",
+              config_path: str = "configs/rag_sources.yaml",
+              skip_if_exists: bool = True) -> bool:
+    """
+    If kb.jsonl exists -> return True.
+    Else -> build from config and return True on success.
+    """
+    out = Path(out_jsonl)
+    if skip_if_exists and out.exists() and out.stat().st_size > 0:
+        log.info("KB already present at %s (skipping build)", out)
+        return True
+    n = build_kb_from_config(config_path=config_path, out_jsonl=out_jsonl)
+    return n > 0
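For orientation, a small sketch (not part of the commit) of how the pieces above behave; it assumes this repo's layout and is run from the project root with PYTHONPATH=. set:

# Sketch: exercise chunk_text() and ensure_kb() directly.
from app.core.rag.build import chunk_text, ensure_kb

text = ("A" * 1000) + "\n\n" + "a short trailing paragraph"
chunks = chunk_text(text, max_chars=800, overlap=120)
# The 1000-char paragraph is windowed (an 800-char chunk, then overlapping
# tails); the short paragraph becomes its own chunk.
print([len(c) for c in chunks])

# ensure_kb() is idempotent: it skips the (network-bound) build whenever
# data/kb.jsonl already exists and is non-empty.
print(ensure_kb(out_jsonl="data/kb.jsonl", config_path="configs/rag_sources.yaml"))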
app/main.py CHANGED
@@ -1,3 +1,4 @@
+# app/main.py
 from __future__ import annotations
 
 import logging
@@ -9,19 +10,16 @@ from typing import Any, Dict
 from fastapi import FastAPI
 from fastapi.responses import RedirectResponse
 
-# --- ADDED: Import dependencies needed for pre-loading ---
-from .deps import get_settings
-from .services.chat_service import get_retriever
-
-# -----------------------------------------------------------------------------
-# Early: load .env (so HF_TOKEN, ADMIN_TOKEN, etc. are available locally)
-# -----------------------------------------------------------------------------
+# ---- Early env load (HF_TOKEN, ADMIN_TOKEN, GITHUB_TOKEN, etc.) ----
 def _load_env_file(paths: list[str]) -> None:
-    """Load environment variables from the first existing path in `paths`.
-    Prefer python-dotenv if present; otherwise use a tiny fallback parser."""
+    """
+    Load environment variables from the first existing path in `paths`.
+    Prefer python-dotenv if present; otherwise use a tiny fallback parser.
+    Does not override pre-existing env vars (e.g., Space Secrets).
+    """
     logger = logging.getLogger("uvicorn.error")
 
-    # 1) Try python-dotenv (best)
+    # 1) Try python-dotenv
     try:
         from dotenv import load_dotenv  # type: ignore
         for p in paths:
@@ -32,7 +30,7 @@ def _load_env_file(paths: list[str]) -> None:
         logger.info("No .env file found in %s (skipping)", paths)
         return
     except Exception:
-        # 2) Fallback: simple parser
+        # 2) Fallback minimal parser
        for p in paths:
            if not os.path.exists(p):
                continue
@@ -53,7 +51,7 @@ def _load_env_file(paths: list[str]) -> None:
                     val.startswith("'") and val.endswith("'")
                 ):
                     val = val[1:-1]
-                # do not clobber existing env (Space Secrets)
+                # do not clobber existing env (e.g., HF Secrets)
                 os.environ.setdefault(key, val)
             logger.info("Loaded environment from %s (fallback parser)", p)
             return
@@ -62,13 +60,17 @@ def _load_env_file(paths: list[str]) -> None:
 
     logger.info("No .env loaded (none found / parsers failed)")
 
-# Try typical locations for local dev. HF Spaces will ignore this and use Secrets.
+# Try common local locations. HF Spaces will rely on Secrets instead.
 _load_env_file([".env", "configs/.env", ".env.local", "configs/.env.local"])
 
 
-# -----------------------------------------------------------------------------
-# Middlewares
-# -----------------------------------------------------------------------------
+# ---- RAG bootstrap & warm-up ----
+from .deps import get_settings
+from .services.chat_service import get_retriever
+from .core.rag.build import ensure_kb
+
+
+# ---- Middlewares ----
 try:
     from .middleware import attach_middlewares  # singular
 except Exception:
@@ -80,11 +82,11 @@ except Exception:
         "attach_middlewares not found; continuing without custom middlewares."
     )
 
-# -----------------------------------------------------------------------------
-# Routers
-# -----------------------------------------------------------------------------
+
+# ---- Routers ----
 from .routers import health, plan, chat
 
+# Optional UI bundle (/, /chat, /dev)
 try:
     from .ui import router as ui_router  # type: ignore
     HAS_UI = True
@@ -105,12 +107,27 @@ async def lifespan(app: FastAPI):
     app.state.started_at = time.time()
     app.state.version = os.getenv("APP_VERSION", "1.0.0")
 
-    # --- ADDED: Pre-load the RAG model and index on startup ---
     logger = logging.getLogger("uvicorn.error")
+
+    # 1) Build KB on first boot (skips if already present)
+    try:
+        if ensure_kb(
+            out_jsonl="data/kb.jsonl",
+            config_path="configs/rag_sources.yaml",
+            skip_if_exists=True,
+        ):
+            logger.info("KB ready at data/kb.jsonl")
+        else:
+            logger.warning("KB build produced no records; running LLM-only.")
+    except Exception as e:
+        logger.warning("KB build failed (%s); running LLM-only.", e)
+
+    # 2) Warm up RAG retriever (indexes data/kb.jsonl if present)
     logger.info("Warming up RAG retriever...")
     get_retriever(get_settings())
     logger.info("RAG retriever is ready.")
-
+
+    # 3) Boot log
     hf_token_present = bool(os.getenv("HF_TOKEN"))
     logger.info(
         "matrix-ai starting (version=%s, port=%s, hf_token_present=%s)",
@@ -118,13 +135,12 @@ async def lifespan(app: FastAPI):
         os.getenv("PORT", "7860"),
         "yes" if hf_token_present else "no",
     )
+
     try:
         yield
     finally:
         uptime = time.time() - getattr(app.state, "started_at", time.time())
-        logger.info(
-            "matrix-ai shutting down (uptime=%.2fs)", uptime
-        )
+        logger.info("matrix-ai shutting down (uptime=%.2fs)", uptime)
 
 
 def create_app() -> FastAPI:
@@ -138,14 +154,19 @@ def create_app() -> FastAPI:
         lifespan=lifespan,
     )
 
+    # Middlewares (gzip, CORS, rate-limit, req-logs, etc.)
     attach_middlewares(app)
+
+    # Core routers
     app.include_router(health.router, tags=["Health"])
     app.include_router(plan.router, prefix="/v1", tags=["Planning"])
     app.include_router(chat.router, prefix="/v1", tags=["Chat"])
 
+    # UI (/, /chat, /dev). Your ui.py already defines "/" → /chat
     if HAS_UI:
        app.include_router(ui_router, tags=["UI"])
     else:
+        # Minimal root so HF root probes pass even without UI
         @app.get("/", include_in_schema=False)
         async def root() -> Dict[str, Any]:
             return {
@@ -155,9 +176,12 @@ def create_app() -> FastAPI:
                 "docs": "/docs",
                 "endpoints": {"plan": "/v1/plan", "chat": "/v1/chat", "healthz": "/healthz"},
             }
+
         @app.get("/home", include_in_schema=False)
         async def home_redirect():
             return RedirectResponse(url="/docs", status_code=302)
+
     return app
 
-app = create_app()
+
+app = create_app()
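A quick way to verify the new startup sequence end to end is to boot the app under Starlette's TestClient, which runs the lifespan hook (KB build or skip, then retriever warm-up) before serving the first request. A minimal sketch (not part of the commit), assuming fastapi and httpx are installed and the health router exposes /healthz as the endpoints map above suggests:

# Sketch: entering the TestClient context runs the lifespan hook above,
# i.e. ensure_kb() followed by get_retriever(), before any request.
from fastapi.testclient import TestClient
from app.main import create_app

with TestClient(create_app()) as client:
    resp = client.get("/healthz")
    print(resp.status_code)  # expect 200 once warm-up has completed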
app/middleware.py CHANGED
@@ -1,30 +1,68 @@
 import time
 import logging
+import json
 from typing import Callable
 from fastapi import FastAPI, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.gzip import GZipMiddleware
-from pythonjsonlogger import jsonlogger
+
+# Try to import python-json-logger; fall back to a tiny JSON formatter if missing.
+try:
+    from pythonjsonlogger import jsonlogger  # type: ignore[import-not-found]
+    _HAS_PY_JSON_LOGGER = True
+except Exception:
+    _HAS_PY_JSON_LOGGER = False
+
 from .deps import get_settings
 from .core.rate_limit import RateLimiter
 from .core.logging import add_trace_id
 
+# ---- Fallback JSON formatter (if python-json-logger isn't available) ----
+class _SimpleJsonFormatter(logging.Formatter):
+    def format(self, record: logging.LogRecord) -> str:
+        payload = {
+            "asctime": self.formatTime(record, "%Y-%m-%d %H:%M:%S"),
+            "name": record.name,
+            "levelname": record.levelname,
+            "message": record.getMessage(),
+            # We attach trace_id via logger.info(..., extra={"trace_id": "..."}).
+            "trace_id": getattr(record, "trace_id", None),
+        }
+        try:
+            return json.dumps(payload, ensure_ascii=False)
+        except Exception:
+            # Last-ditch plain log if JSON serialization ever fails
+            return (
+                f'{payload["asctime"]} {payload["name"]} {payload["levelname"]} '
+                f'{payload["message"]} trace_id={payload["trace_id"]}'
+            )
+
 # Setup structured logging
 logger = logging.getLogger("matrix-ai")
 if not logger.handlers:
     logger.setLevel(logging.INFO)
     handler = logging.StreamHandler()
-    formatter = jsonlogger.JsonFormatter(
-        '%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s'
-    )
+    if _HAS_PY_JSON_LOGGER:
+        # Same fields you had; python-json-logger builds JSON from this format string
+        formatter = jsonlogger.JsonFormatter(
+            "%(asctime)s %(name)s %(levelname)s %(message)s %(trace_id)s"
+        )
+    else:
+        formatter = _SimpleJsonFormatter()
+        logging.getLogger("uvicorn.error").warning(
+            "python-json-logger not found; using a minimal JSON formatter."
+        )
     handler.setFormatter(formatter)
     logger.addHandler(handler)
 
 _rate_limiter = RateLimiter()
 
-def attach_middlewares(app: FastAPI):
+def attach_middlewares(app: FastAPI) -> None:
     """Attaches all required middlewares to the FastAPI app."""
+    # NOTE: We keep GZip, but your SSE endpoints already set `Content-Encoding: identity`
+    # so they won't be buffered/compressed.
     app.add_middleware(GZipMiddleware, minimum_size=512)
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
@@ -35,20 +73,25 @@ def attach_middlewares(app: FastAPI):
 
     @app.middleware("http")
     async def rate_limit_and_log_middleware(request: Request, call_next: Callable):
+        # Attach per-request trace id
         add_trace_id(request)
+
         settings = get_settings()
         client_ip = request.client.host if request.client else "unknown"
 
-        if not _rate_limiter.allow(client_ip, request.url.path, settings.limits.rate_per_min):
+        # Simple fixed-window limiter
+        if not _rate_limiter.allow(
+            client_ip, request.url.path, settings.limits.rate_per_min
+        ):
             return Response(status_code=429, content="Rate limit exceeded")
 
         start_time = time.time()
         response = await call_next(request)
-        process_time = (time.time() - start_time) * 1000
+        process_time = (time.time() - start_time) * 1000.0
        response.headers["X-Process-Time-Ms"] = f"{process_time:.2f}"
 
         logger.info(
             f'"{request.method} {request.url.path}" {response.status_code}',
-            extra={'trace_id': getattr(request.state, 'trace_id', 'N/A')}
+            extra={"trace_id": getattr(request.state, "trace_id", "N/A")},
         )
         return response
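To see what the fallback path emits when python-json-logger is absent, a small sketch (not part of the commit) using the _SimpleJsonFormatter defined above; the printed timestamp is illustrative:

# Sketch: wire the fallback formatter to a logger and pass trace_id via
# `extra`, exactly as the middleware's logger.info call does.
import logging

demo = logging.getLogger("demo")
h = logging.StreamHandler()
h.setFormatter(_SimpleJsonFormatter())
demo.addHandler(h)
demo.warning('"GET /healthz" 200', extra={"trace_id": "abc-123"})
# -> {"asctime": "2025-01-01 12:00:00", "name": "demo", "levelname": "WARNING",
#     "message": "\"GET /healthz\" 200", "trace_id": "abc-123"}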
app/routers/chat.py CHANGED
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
 import json
-from typing import Any, AsyncIterator, List, Optional
+from typing import Any, Iterator, List, Optional
 
 from fastapi import APIRouter, Depends, HTTPException, Query
 from pydantic import BaseModel, Field
-from starlette.concurrency import run_in_threadpool
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
 from starlette.responses import StreamingResponse
 
 from ..deps import get_settings
@@ -27,12 +27,16 @@ class ChatRequest(BaseModel):
     messages: Optional[List[ChatMessage]] = None
 
     def as_text(self) -> str:
-        if self.query: return self.query
-        if self.question: return self.question
-        if self.prompt: return self.prompt
+        if self.query:
+            return self.query
+        if self.question:
+            return self.question
+        if self.prompt:
+            return self.prompt
         if self.messages:
             for m in reversed(self.messages):
-                if m.role.lower() == "user": return m.content
+                if m.role.lower() == "user":
+                    return m.content
             return self.messages[-1].content
         raise ValueError("Body must include 'query'/'question'/'prompt' or 'messages'")
 
@@ -50,7 +54,7 @@ async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
         raise HTTPException(status_code=422, detail=str(e))
     svc = ChatService(settings)
     try:
-        # Run the blocking call in a thread pool to avoid freezing the server
+        # run blocking client in a threadpool
         answer, sources = await run_in_threadpool(svc.answer_with_sources, text)
         return ChatResponse(answer=answer, sources=sources)
     except PermissionError as e:
@@ -63,7 +67,6 @@ async def chat(req: ChatRequest, settings: Settings = Depends(get_settings)):
 async def chat_get(query: str = Query(...), settings: Settings = Depends(get_settings)):
     svc = ChatService(settings)
     try:
-        # Run the blocking call in a thread pool
         answer, sources = await run_in_threadpool(svc.answer_with_sources, query)
         return ChatResponse(answer=answer, sources=sources)
     except PermissionError as e:
@@ -72,7 +75,6 @@ async def chat_get(query: str = Query(...), settings: Settings = Depends(get_set
         raise HTTPException(status_code=502, detail=f"Inference error: {e}")
 
 
-# ---------- Streaming (SSE) ----------
 def _sse_line(obj: Any) -> str:
     payload = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
     return f"data: {payload}\n\n"
@@ -80,25 +82,29 @@ def _sse_line(obj: Any) -> str:
 
 @router.get("/chat/stream")
 async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_settings)):
+    """
+    SSE of token deltas. We iterate the sync streaming client in a threadpool
+    so the event loop stays free.
+    """
     svc = ChatService(settings)
 
-    async def gen() -> AsyncIterator[str]:
-        # Anti-buffer padding and initial ping
+    def sync_stream() -> Iterator[str]:
+        # send anti-buffer padding + ping immediately
         yield ":" + (" " * 2048) + "\n\n"
+        yield "retry: 1500\n\n"
         yield "event: ping\ndata: 0\n\n"
-
+
+        any_tokens = False
         try:
-            # Run the blocking retrieval part in a thread pool, then stream the results
-            stream_generator = await run_in_threadpool(svc.stream_answer, query)
-            any_tokens = False
-            for token in stream_generator:
+            for token in svc.stream_answer(query):
                 if token:
                     any_tokens = True
                     yield _sse_line({"delta": token})
-
             if not any_tokens:
                 yield _sse_line({"delta": ""})
             yield _sse_line("[DONE]")
+        except GeneratorExit:
+            return
         except Exception as e:
             yield _sse_line({"error": str(e)})
 
@@ -108,7 +114,12 @@ async def chat_stream(query: str = Query(...), settings: Settings = Depends(get_
         "Connection": "keep-alive",
         "Content-Encoding": "identity",
     }
-    return StreamingResponse(gen(), media_type="text/event-stream; charset=utf-8", headers=headers)
+    # iterate the sync generator in a threadpool (non-blocking for the loop)
+    return StreamingResponse(
+        iterate_in_threadpool(sync_stream()),
+        media_type="text/event-stream; charset=utf-8",
+        headers=headers,
+    )
 
 
 @router.post("/chat/stream")
@@ -117,4 +128,4 @@ async def chat_stream_post(req: ChatRequest, settings: Settings = Depends(get_se
         q = req.as_text()
     except ValueError as e:
         raise HTTPException(status_code=422, detail=str(e))
-    return await chat_stream(query=q, settings=settings)
+    return await chat_stream(query=q, settings=settings)
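For reference, a minimal SSE consumer for the GET endpoint (a sketch, not part of the commit; assumes the server is listening on localhost:7860, the default PORT used elsewhere in this commit):

# Sketch: stream token deltas from /v1/chat/stream. SSE comment lines (the
# ":" padding), the retry hint, and blank separators carry no "data: " payload,
# so they are skipped; the ping's "data: 0" is filtered by the isinstance check.
import json
import requests

url = "http://localhost:7860/v1/chat/stream"
with requests.get(url, params={"query": "What is MatrixHub?"}, stream=True, timeout=60) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        msg = json.loads(payload)
        if isinstance(msg, dict) and "delta" in msg:
            print(msg["delta"], end="", flush=True)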
app/services/chat_service.py CHANGED
@@ -1,9 +1,12 @@
+# app/services/chat_service.py
 from __future__ import annotations
 
 import logging
 import os
+import re
+import threading
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Optional
 
 from ..core.config import Settings
 from ..core.inference.client import RouterRequestsClient
@@ -11,33 +14,203 @@ from ..core.rag.retriever import Retriever
 
 logger = logging.getLogger(__name__)
 
+# --- Optional cross-encoder reranker (graceful fallback) ---
+try:
+    from sentence_transformers import CrossEncoder  # type: ignore
+except Exception:  # pragma: no cover
+    CrossEncoder = None  # type: ignore
+
 SYSTEM_PROMPT = (
-    "You are MATRIX-AI, a concise, helpful assistant for the Matrix EcoSystem. "
-    "Answer clearly and briefly. If unsure, say so."
+    "You are MATRIX-AI, a concise, helpful assistant for the Matrix EcoSystem.\n"
+    "You MUST use the provided CONTEXT. If an answer is not supported by the context, say you don't know.\n"
+    "Prefer short, clear sentences. Include product/feature names exactly as written in the context.\n"
 )
 
-# --- Singleton instance for the expensive Retriever class ---
-_retriever_instance: Retriever | None = None
+# Thread-safe singleton retriever
+_retriever_instance: Optional[Retriever] = None
+_retriever_lock = threading.Lock()
+
 
-def get_retriever(settings: Settings) -> Retriever | None:
-    """Initializes and returns a single instance of the Retriever."""
+def get_retriever(settings: Settings) -> Optional[Retriever]:
+    """Initialize and return a single Retriever instance (double-checked locking)."""
     global _retriever_instance
     if _retriever_instance is not None:
         return _retriever_instance
 
     kb_path = os.getenv("RAG_KB_PATH", "data/kb.jsonl")
-    try:
-        if Path(kb_path).exists():
-            _retriever_instance = Retriever(kb_path=kb_path, top_k=settings.rag.top_k)
-            logger.info("RAG enabled with KB at %s (top_k=%d)", kb_path, settings.rag.top_k)
-        else:
-            logger.info("RAG KB not found at %s — running LLM-only.", kb_path)
-    except Exception as e:
-        logger.warning("RAG disabled (failed to initialize Retriever: %s)", e)
-
-    return _retriever_instance
+    if not Path(kb_path).exists():
+        logger.info("RAG KB not found at %s — running LLM-only.", kb_path)
+        return None
+
+    with _retriever_lock:
+        if _retriever_instance is not None:
+            return _retriever_instance
+        try:
+            _retriever_instance = Retriever(kb_path=kb_path, top_k=settings.rag.top_k)
+            logger.info("RAG enabled with KB at %s (top_k=%d)", kb_path, settings.rag.top_k)
+        except Exception as e:
+            logger.warning("RAG disabled (failed to initialize Retriever: %s)", e)
+            _retriever_instance = None
+        return _retriever_instance
+
+
+# ----------------------------
+# RAG utilities (ranking & snippets)
+# ----------------------------
+
+_ALIAS_TABLE: Dict[str, List[str]] = {
+    # canonical -> aliases
+    "matrixhub": ["matrix hub", "matrixhub", "hub api", "catalog", "registry", "cas"],
+    "mcp": ["model context protocol", "mcp", "manifest", "server manifest", "admin api"],
+    "agent-matrix": ["agent-matrix", "matrix agents", "matrix ecosystem", "matrix toolkit"],
+}
+
+_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
+
+
+def _normalize(text: str) -> List[str]:
+    return [t.lower() for t in _WORD_RE.findall(text)]
+
+
+def _expand_query(q: str) -> str:
+    """Add domain aliases to help the embedding retrieve the right docs."""
+    ql = q.lower()
+    extras: List[str] = []
+    for canon, variants in _ALIAS_TABLE.items():
+        if any(v in ql for v in variants):
+            extras.extend([canon] + variants)
+    if extras:
+        return q + " | " + " ".join(sorted(set(extras)))
+    return q
+
+
+def _keyword_overlap_score(query: str, text: str) -> float:
+    """Simple lexical grounding score: Jaccard over unique tokens (stopword-agnostic)."""
+    q_tokens = set(_normalize(query))
+    d_tokens = set(_normalize(text))
+    if not q_tokens or not d_tokens:
+        return 0.0
+    inter = len(q_tokens & d_tokens)
+    union = len(q_tokens | d_tokens)
+    return inter / max(1, union)
+
+
+def _domain_boost(text: str) -> float:
+    t = text.lower()
+    boost = 0.0
+    for term in ("matrixhub", "hub api", "catalog", "mcp", "server manifest", "cas"):
+        if term in t:
+            boost += 0.05
+    return min(boost, 0.25)
+
+
+def _best_paragraphs(text: str, query: str, max_chars: int = 700) -> str:
+    """
+    Split by blank lines and pick 1-2 best paragraphs by lexical overlap.
+    Keep it compact to avoid swamping the LLM.
+    """
+    paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
+    if not paras:
+        return text[:max_chars]
+
+    scored = [(p, _keyword_overlap_score(query, p)) for p in paras]
+    scored.sort(key=lambda x: x[1], reverse=True)
+    picked: List[str] = []
+    used = 0
+    for p, _s in scored[:4]:
+        if used >= max_chars:
+            break
+        picked.append(p)
+        used += len(p) + 2
+        if used >= max_chars or len(picked) >= 2:
+            break
+    return "\n".join(picked)
+
+
+def _cross_encoder_scores(
+    model: Optional["CrossEncoder"],
+    query: str,
+    docs: List[Dict],
+    max_pairs: int = 50,
+) -> Optional[List[float]]:
+    if not model:
+        return None
+    try:
+        pairs = [(query, d["text"][:1200]) for d in docs[:max_pairs]]
+        return list(model.predict(pairs))
+    except Exception as e:
+        logger.warning("Cross-encoder scoring failed; continuing without it (%s)", e)
+        return None
+
+
+def _rerank_docs(
+    docs: List[Dict],
+    query: str,
+    k_final: int,
+    reranker: Optional["CrossEncoder"] = None,
+) -> List[Dict]:
+    """
+    Combine vector score, keyword overlap, domain boost, and optional cross-encoder.
+    Score = 0.55*vec + 0.35*lex + 0.10*boost (+ 0.20*ce if available, rescaled).
+    """
+    if not docs:
+        return []
+
+    # Normalize vector scores (cosine similarities 0..1-ish) to 0..1
+    vec_scores = [float(d.get("score", 0.0)) for d in docs]
+    if vec_scores:
+        vmin = min(vec_scores)
+        vmax = max(vec_scores)
+        rng = max(1e-6, (vmax - vmin))
+        vec_norm = [(v - vmin) / rng for v in vec_scores]
+    else:
+        vec_norm = [0.0] * len(docs)
+
+    lex_scores = [_keyword_overlap_score(query, d["text"]) for d in docs]
+    boosts = [_domain_boost(d["text"]) for d in docs]
+
+    ce_scores = _cross_encoder_scores(reranker, query, docs)
+    if ce_scores:
+        # Min-max normalize CE too
+        cmin, cmax = min(ce_scores), max(ce_scores)
+        crng = max(1e-6, (cmax - cmin))
+        ce_norm = [(c - cmin) / crng for c in ce_scores]
+    else:
+        ce_norm = None
+
+    merged: List[Tuple[float, Dict]] = []
+    for i, d in enumerate(docs):
+        score = 0.55 * vec_norm[i] + 0.35 * lex_scores[i] + 0.10 * boosts[i]
+        if ce_norm is not None:
+            score = 0.80 * score + 0.20 * ce_norm[i]
+        merged.append((score, d))
+
+    merged.sort(key=lambda x: x[0], reverse=True)
+    top = [d for _s, d in merged[:k_final]]
+    return top
 
 
+def _build_context_from_docs(docs: List[Dict], query: str, max_blocks: int = 4) -> Tuple[str, List[str]]:
+    """
+    Build a compact CONTEXT section using best paragraphs from top docs.
+    Return (context_text, sources).
+    """
+    blocks: List[str] = []
+    sources: List[str] = []
+    for i, d in enumerate(docs[:max_blocks]):
+        snip = _best_paragraphs(d["text"], query, max_chars=700)
+        src = d.get("source", f"kb:{i}")
+        blocks.append(f"[{i+1}] {snip}\n(source: {src})")
+        sources.append(src)
+    if not blocks:
+        return "", []
+    prelude = "CONTEXT (use only these facts; if missing, say you don't know):"
+    return prelude + "\n\n" + "\n\n".join(blocks), sources
+
+
+# ----------------------------
+# Service
+# ----------------------------
 class ChatService:
     def __init__(self, settings: Settings):
         self.settings = settings
@@ -46,51 +219,80 @@ class ChatService:
             fallback=settings.model.fallback,
             provider=getattr(settings.model, "provider", None),
             max_retries=2,
+            connect_timeout=10.0,
+            read_timeout=60.0,
         )
-        # Get the singleton retriever instance
+        # RAG (singleton)
         self.retriever = get_retriever(settings)
 
-    def _build_context(self, query: str) -> Tuple[str, List[str]]:
+        # Optional cross-encoder (large; disable via env RAG_RERANK=false)
+        self.reranker = None
+        use_rerank = os.getenv("RAG_RERANK", "true").lower() in ("1", "true", "yes")
+        if use_rerank and CrossEncoder is not None:
+            try:
+                self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-2-v2")
+                logger.info("RAG cross-encoder reranker enabled.")
+            except Exception as e:
+                logger.warning("Reranker disabled: %s", e)
+
+    # ---------- RAG core ----------
+    def _retrieve_best(self, query: str) -> Tuple[str, List[str]]:
+        """
+        Retrieve many, rerank, and build a compact, high-signal CONTEXT.
+        Returns (context_text, sources).
+        """
         if not self.retriever:
             return "", []
-        docs = self.retriever.retrieve(query, self.settings.rag.top_k)
-        if not docs:
+
+        expanded = _expand_query(query)
+        # Retrieve a wider candidate pool, then rerank.
+        k_base = max(4, int(self.settings.rag.top_k) * 5)
+        try:
+            cands = self.retriever.retrieve(expanded, k=k_base)  # [{'text','source','score'}]
+        except Exception as e:
+            logger.warning("Retriever failed (%s); falling back to LLM-only.", e)
+            return "", []
+
+        if not cands:
             return "", []
-        blocks = [f"[{i+1}] {d['text']} (source: {d['source']})" for i, d in enumerate(docs)]
-        context = "CONTEXT (use only these facts; if missing, say you don't know):\n" + "\n\n".join(blocks)
-        sources = [d["source"] for d in docs]
-        return context, sources
+
+        top = _rerank_docs(cands, query, k_final=max(3, self.settings.rag.top_k), reranker=self.reranker)
+        ctx, sources = _build_context_from_docs(top, query, max_blocks=max(3, self.settings.rag.top_k))
+        return ctx, sources
 
     def _augment(self, query: str) -> Tuple[str, List[str]]:
         """
         Build the final user message (with optional CONTEXT) and return sources.
         """
-        ctx, sources = self._build_context(query)
-
-        # --- THIS IS THE CORRECTED PROMPT ---
+        ctx, sources = self._retrieve_best(query)
         if ctx:
-            # New, clearer instruction format
-            augmented = f"{ctx}\n\nBased only on the context provided above, answer the following question.\nQuestion: {query}"
+            user_msg = (
+                f"{ctx}\n\n"
+                "Based only on the context above, answer the question succinctly.\n"
+                f"Question: {query}\nAnswer:"
+            )
         else:
-            # If no context, just pass the original query
-            augmented = query
-
-        return augmented, sources
+            user_msg = query  # LLM-only fallback
+        return user_msg, sources
 
-    # Note: These methods are now called from a thread pool in the router
+    # ---------- Non-stream ----------
     def answer_with_sources(self, query: str) -> Tuple[str, List[str]]:
         user_msg, sources = self._augment(query)
         text = self.client.chat_nonstream(
-            SYSTEM_PROMPT, user_msg,
+            SYSTEM_PROMPT,
+            user_msg,
            max_tokens=self.settings.model.max_new_tokens,
            temperature=self.settings.model.temperature,
         )
         return text, sources
 
+    # ---------- Stream ----------
     def stream_answer(self, query: str):
         user_msg, _ = self._augment(query)
+        # SYNC generator yielding token deltas; router wraps in SSE
         return self.client.chat_stream(
-            SYSTEM_PROMPT, user_msg,
+            SYSTEM_PROMPT,
+            user_msg,
             max_tokens=self.settings.model.max_new_tokens,
             temperature=self.settings.model.temperature,
-        )
+        )
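The blended score is easy to sanity-check by hand. With no cross-encoder (reranker=None) the formula in _rerank_docs reduces to 0.55*vec + 0.35*lex + 0.10*boost; a toy sketch (not part of the commit):

# Sketch: _rerank_docs with two toy candidates and no cross-encoder.
# Doc "a" wins on lexical overlap plus the domain boost for
# "matrixhub"/"hub api"/"catalog", on top of its slightly higher vector score.
docs = [
    {"text": "MatrixHub catalog and Hub API overview", "source": "a", "score": 0.42},
    {"text": "Unrelated release notes", "source": "b", "score": 0.40},
]
top = _rerank_docs(docs, "What is the MatrixHub catalog?", k_final=1, reranker=None)
print(top[0]["source"])  # -> "a"  (about 0.64 vs 0.0 after min-max normalization)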
configs/rag_sources.yaml ADDED
@@ -0,0 +1,41 @@
+# Where to pull documentation from when building the RAG knowledge base.
+# You can add/remove repos here; the builder will respect these sources.
+
+github:
+  # 1) Explicit repos (stable)
+  repos:
+    - owner: agent-matrix
+      name: matrix-cli
+      branch: master
+      docs_paths: ["docs"]          # folders to harvest (recursive)
+      include_readme: true
+    - owner: agent-matrix
+      name: matrix-python-sdk
+      branch: master
+      docs_paths: ["docs"]
+      include_readme: true
+    - owner: agent-matrix
+      name: matrixlink
+      branch: master
+      docs_paths: ["docs"]
+      include_readme: true
+    - owner: agent-matrix
+      name: matrix-hub
+      branch: master
+      docs_paths: ["docs"]
+      include_readme: true
+
+  # 2) Optionally scan an entire org for repos (README + docs/ if present)
+  #    Comment out if you want only the explicit list above.
+  orgs:
+    - agent-matrix
+
+# Local content in THIS repo (optional but recommended)
+local:
+  paths:
+    - docs        # everything under /docs
+    - README.md   # root readme
+  glob: "**/*.md" # or "**/*.{md,mdx,txt}"
+
+# Extra public URLs to pull (optional)
+urls: []
data/kb.jsonl CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -18,3 +18,8 @@ pytest
 ruff
 mypy
 pytest-asyncio
+
+
+requests>=2.32.0
+beautifulsoup4>=4.12.3  # only used if you later add generic HTML URLs
+PyYAML>=6.0.1
scripts/build_kb.py ADDED
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+"""
+Builds/refreshes the local RAG KB (data/kb.jsonl) from GitHub + local docs.
+
+Usage:
+  python scripts/build_kb.py --config configs/rag_sources.yaml --out data/kb.jsonl
+  python scripts/build_kb.py --config ... --out ... --force
+"""
+
+from __future__ import annotations
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+
+# --- Ensure THIS repo is first on sys.path (avoid clashing 'app' packages) ---
+ROOT = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(ROOT))
+
+logger = logging.getLogger("build_kb")
+logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
+
+# Import the builder from this project
+try:
+    from app.core.rag.build import build_kb_from_config, ensure_kb  # type: ignore
+except Exception as e:  # pragma: no cover
+    logger.error("Failed importing KB builder from app.core.rag.build: %s", e)
+    logger.error("Make sure you're running from the project root and PYTHONPATH includes '.'.")
+    sys.exit(2)
+
+def main() -> int:
+    p = argparse.ArgumentParser()
+    p.add_argument("--config", required=True, help="Path to configs/rag_sources.yaml")
+    p.add_argument("--out", required=True, help="Output JSONL file, e.g., data/kb.jsonl")
+    p.add_argument("--force", action="store_true", help="Delete output file first, then rebuild")
+    args = p.parse_args()
+
+    out_path = Path(args.out)
+    if args.force and out_path.exists():
+        logger.info("Removing existing %s", out_path)
+        out_path.unlink()
+
+    # If you want a one-liner that skips if exists, use ensure_kb:
+    #   created = ensure_kb(out_jsonl=args.out, config_path=args.config, skip_if_exists=True)
+    #   logger.info("KB %s at %s", "ready" if created else "unchanged", args.out)
+
+    # Otherwise, always (re)build:
+    n = build_kb_from_config(config_path=args.config, out_jsonl=args.out)
+    logger.info("Wrote %d records to %s", n, args.out)
+    return 0
+
+if __name__ == "__main__":
+    raise SystemExit(main())