# openbrain / app.py
# Last commit by DreamyDetective: "Upload folder using huggingface_hub"
# Revision 9e7395e (verified)
"""
Open Brain — Gradio Space: local embeddings (sentence-transformers), SQLite cache,
HNSW approximate search, Hugging Face Datasets as remote store.
"""
from __future__ import annotations
import asyncio
import json
import os
import sqlite3
import threading
import uuid
from datetime import datetime, timezone
from typing import Any
import gradio as gr
import numpy as np
try:
    # Provided by chroma-hnswlib (Chroma fork; same API as upstream hnswlib).
    import hnswlib
except ImportError as exc:
    # Fail fast with an actionable install hint, keeping the original error as the cause.
    raise ImportError(
        "HNSW bindings not found. Install the Chroma fork: pip install chroma-hnswlib"
    ) from exc
from datasets import Dataset, Features, Sequence, Value
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _env_int(name: str, default: str) -> int:
    """Read environment variable ``name`` as an ``int``, using ``default`` when unset."""
    return int(os.environ.get(name, default))


# Sentence-transformers model used for every embedding.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Hub dataset acting as the remote store; an empty value disables Hub sync.
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "").strip()
# Local SQLite cache file.
SQLITE_PATH = os.environ.get("SQLITE_PATH", "thoughts.db")
# HNSW graph / search tuning knobs.
HNSW_M = _env_int("HNSW_M", "16")
HNSW_EF_CONSTRUCTION = _env_int("HNSW_EF_CONSTRUCTION", "200")
HNSW_EF_SEARCH = _env_int("HNSW_EF_SEARCH", "64")
# Retrieve extra neighbors for post-filtering (threshold + metadata).
HNSW_K_MULTIPLIER = _env_int("HNSW_K_MULTIPLIER", "10")
# Gradio MCP server (SSE/HTTP tools). Set GRADIO_MCP_SERVER=false to disable.
os.environ.setdefault("GRADIO_MCP_SERVER", "true")
def _utc_now_iso() -> str:
"""Return the current UTC time as an ISO 8601 string."""
return datetime.now(timezone.utc).isoformat()
def _hub_token() -> str | None:
"""Return the Hugging Face Hub token from ``HF_TOKEN`` or ``HUGGING_FACE_HUB_TOKEN``."""
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
def json_contains(container: Any, subset: Any) -> bool:
    """Return whether ``container`` contains ``subset`` (PostgreSQL JSONB ``@>`` semantics).

    Rules, applied recursively:

    * ``None`` or an empty dict as the top-level ``subset`` matches any ``container``.
    * Objects: every key of ``subset`` must exist in ``container`` and its value must be
      contained in the container's value (so ``{"a": {}}`` matches any dict under ``"a"``).
    * Arrays: every element of ``subset`` must be contained in *some* element of
      ``container``; order and duplicates are ignored (``["a"]`` matches ``["b", "a"]``).
    * Scalars: equal values, except that JSON booleans never equal numbers
      (``true`` does not match ``1``), unlike plain Python ``==``.

    Args:
        container: Decoded JSON value being tested (typically a metadata dict).
        subset: Decoded JSON value that must be contained in ``container``.

    Returns:
        ``True`` if ``container`` contains ``subset``.
    """
    if subset is None or (isinstance(subset, dict) and len(subset) == 0):
        return True
    return _json_value_contains(container, subset)


def _json_value_contains(container: Any, subset: Any) -> bool:
    """Recursive containment check used by :func:`json_contains` (no empty-subset shortcut)."""
    if isinstance(subset, dict):
        return isinstance(container, dict) and all(
            key in container and _json_value_contains(container[key], sub_val)
            for key, sub_val in subset.items()
        )
    if isinstance(subset, list):
        return isinstance(container, list) and all(
            any(_json_value_contains(item, sub_item) for item in container)
            for sub_item in subset
        )
    return _json_scalar_equal(container, subset)


def _json_scalar_equal(a: Any, b: Any) -> bool:
    """Compare JSON scalars, keeping booleans distinct from numbers (``True`` != ``1``)."""
    if isinstance(a, bool) or isinstance(b, bool):
        return isinstance(a, bool) and isinstance(b, bool) and a == b
    return a == b
def _parse_metadata(raw: str | dict | None) -> dict[str, Any]:
"""Parse stored metadata into a dict (accepts JSON string, dict, or None)."""
if raw is None:
return {}
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
return json.loads(raw) if raw else {}
except json.JSONDecodeError:
return {}
return {}
# ---------------------------------------------------------------------------
# Thought store
# ---------------------------------------------------------------------------
class ThoughtStore:
    """Persist thoughts in SQLite, search embeddings with HNSW, sync snapshots to the Hub.

    ``capture_thought``, ``search_thoughts``, ``list_thoughts``, and ``thought_stats`` are
    registered as MCP tools (``@gr.mcp.tool()`` + ``gr.api()`` on the Blocks app).

    Threading: one SQLite connection (opened with ``check_same_thread=False``) and the
    HNSW index are shared by worker threads. After startup, every read/write of them
    goes through :attr:`_lock`. Hub uploads are additionally serialized by
    :attr:`_push_lock`, so an older snapshot can never finish uploading after a newer one.
    """

    def __init__(self) -> None:
        """Initialize locks and lazy-loaded model, DB connection, and index state."""
        # Guards the shared SQLite connection, the HNSW index and ``_label_to_id``.
        self._lock = threading.Lock()
        # Held for the whole snapshot + upload in ``push_to_hub``.
        self._push_lock = threading.Lock()
        self._conn: sqlite3.Connection | None = None
        self._model: SentenceTransformer | None = None
        self._dim: int | None = None
        self._index: hnswlib.Index | None = None
        # HNSW integer label ``i`` refers to the thought with id ``_label_to_id[i]``.
        self._label_to_id: list[str] = []
        self._startup_message = ""

    def _connect(self) -> sqlite3.Connection:
        """Return the shared SQLite connection, creating it on first use."""
        if self._conn is None:
            # Shared across worker threads; access is serialized with ``_lock``.
            self._conn = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
            self._conn.row_factory = sqlite3.Row
        return self._conn

    def _init_sqlite(self) -> None:
        """Create the ``thoughts`` table if it does not exist."""
        c = self._connect()
        c.execute(
            """
            CREATE TABLE IF NOT EXISTS thoughts (
                id TEXT PRIMARY KEY NOT NULL,
                content TEXT NOT NULL,
                embedding BLOB,
                metadata TEXT NOT NULL DEFAULT '{}',
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            )
            """
        )
        c.commit()

    def _ensure_model(self) -> SentenceTransformer:
        """Load ``EMBEDDING_MODEL`` once and cache its output dimension."""
        if self._model is None:
            self._model = SentenceTransformer(EMBEDDING_MODEL)
            self._dim = int(self._model.get_sentence_embedding_dimension())
        return self._model

    def embed(self, text: str) -> np.ndarray:
        """Encode ``text`` to a single L2-normalized float32 embedding vector."""
        model = self._ensure_model()
        v = model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
        return np.asarray(v, dtype=np.float32)

    async def embed_async(self, text: str) -> np.ndarray:
        """Run :meth:`embed` in a worker thread so the event loop stays responsive."""
        return await asyncio.to_thread(self.embed, text)

    async def push_to_hub_async(self) -> str:
        """Run Hub push in a worker thread; failures become a status string (same as sync capture)."""
        return await asyncio.to_thread(self._push_to_hub_catch)

    def rebuild_hnsw(self) -> None:
        """Rebuild the in-memory HNSW index and ``_label_to_id`` from SQLite rows with embeddings.

        Rows whose embedding is NULL or an empty blob are skipped. This method does not
        take :attr:`_lock` (the lock is not re-entrant); runtime callers must already hold it.

        Raises:
            ValueError: If a stored vector's length differs from the model dimension.
        """
        if self._dim is None:
            self._ensure_model()
        assert self._dim is not None
        conn = self._connect()
        rows = conn.execute(
            "SELECT id, embedding FROM thoughts"
            " WHERE embedding IS NOT NULL AND length(embedding) > 0"
            " ORDER BY created_at ASC"
        ).fetchall()
        self._index = None
        self._label_to_id = []
        if not rows:
            return
        vectors = [np.frombuffer(r["embedding"], dtype=np.float32) for r in rows]
        # Check every row (not just the stacked shape) so mixed dims give a clear message.
        bad_dim = next((v.size for v in vectors if v.size != self._dim), None)
        if bad_dim is not None:
            raise ValueError(
                f"Embedding dim {bad_dim} != model dim {self._dim}; clear DB or match EMBEDDING_MODEL."
            )
        data = np.stack(vectors)
        n, dim = data.shape
        index = hnswlib.Index(space="cosine", dim=dim)
        index.init_index(
            max_elements=max(n, 1),
            ef_construction=HNSW_EF_CONSTRUCTION,
            M=HNSW_M,
        )
        index.set_ef(HNSW_EF_SEARCH)
        index.add_items(data, np.arange(n, dtype=np.int64))
        self._index = index
        self._label_to_id = [str(r["id"]) for r in rows]

    @staticmethod
    def _hub_row_to_params(
        row: dict[str, Any],
    ) -> tuple[str, str, bytes | None, str, str, str]:
        """Convert one Hub dataset row into INSERT parameters for the ``thoughts`` table.

        An empty or missing embedding becomes SQL NULL. ``push_to_hub`` writes ``[]``
        for un-embedded rows, and an empty blob would break ``rebuild_hnsw``.
        """
        emb = row.get("embedding")
        emb_bytes: bytes | None = None
        if emb is not None:
            values = emb.tolist() if hasattr(emb, "tolist") else list(emb)
            if values:
                emb_bytes = np.asarray(values, dtype=np.float32).tobytes()
        meta_raw = row.get("metadata", "{}")
        if isinstance(meta_raw, dict):
            meta_s = json.dumps(meta_raw)
        else:
            meta_s = str(meta_raw) if meta_raw else "{}"
        # ``or`` (not a .get default) also covers columns that exist but hold None.
        created = str(row.get("created_at") or _utc_now_iso())
        updated = str(row.get("updated_at") or created)
        return (str(row["id"]), str(row["content"]), emb_bytes, meta_s, created, updated)

    def load_from_hub(self) -> str:
        """Replace local SQLite content from ``HF_DATASET_REPO`` and rebuild HNSW.

        If ``HF_DATASET_REPO`` is unset, only initializes SQLite and the index.
        On Hub load errors, keeps the existing SQLite rows and returns a short status
        message. The replace runs in one transaction, so a failing row rolls back to the
        previous local content instead of leaving a half-deleted table.
        """
        self._init_sqlite()
        if not HF_DATASET_REPO:
            self.rebuild_hnsw()
            return "HF_DATASET_REPO not set — using local SQLite only (no Hub load)."
        conn = self._connect()
        try:
            ds = load_dataset(HF_DATASET_REPO, split="train", token=_hub_token())
        except Exception as e:  # any Hub/network/format failure: fall back to local state
            self.rebuild_hnsw()
            return f"Hub load failed ({e!r}); starting empty or with existing SQLite."
        rows_inserted = 0
        # ``with conn`` commits on success and rolls back on any exception.
        with conn:
            conn.execute("DELETE FROM thoughts")
            for row in ds:
                conn.execute(
                    """INSERT INTO thoughts (id, content, embedding, metadata, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?)""",
                    self._hub_row_to_params(row),
                )
                rows_inserted += 1
        self.rebuild_hnsw()
        return f"Loaded {rows_inserted} thought(s) from {HF_DATASET_REPO} into SQLite; HNSW index built."

    def push_to_hub(self) -> str:
        """Serialize all SQLite rows to a ``datasets.Dataset`` and ``push_to_hub``.

        The row snapshot is taken under :attr:`_lock`. The snapshot and the upload both
        happen while :attr:`_push_lock` is held, so concurrent captures upload in order
        and the last upload always carries the newest snapshot.
        """
        if not HF_DATASET_REPO:
            return "HF_DATASET_REPO not set; nothing pushed."
        token = _hub_token()
        if not token:
            return "No HF_TOKEN / HUGGING_FACE_HUB_TOKEN; cannot push."
        features = Features(
            {
                "id": Value("string"),
                "content": Value("string"),
                "embedding": Sequence(Value("float32")),
                "metadata": Value("string"),
                "created_at": Value("string"),
                "updated_at": Value("string"),
            }
        )
        with self._push_lock:
            with self._lock:
                rows = self._connect().execute(
                    "SELECT id, content, embedding, metadata, created_at, updated_at FROM thoughts ORDER BY created_at ASC"
                ).fetchall()
            records: list[dict[str, Any]] = [
                {
                    "id": str(r["id"]),
                    "content": str(r["content"]),
                    # Un-embedded rows are written as an empty vector.
                    "embedding": (
                        np.frombuffer(r["embedding"], dtype=np.float32).tolist()
                        if r["embedding"]
                        else []
                    ),
                    "metadata": str(r["metadata"] or "{}"),
                    "created_at": str(r["created_at"]),
                    "updated_at": str(r["updated_at"]),
                }
                for r in rows
            ]
            if records:
                ds = Dataset.from_list(records, features=features)
            else:
                ds = Dataset.from_dict({name: [] for name in features}, features=features)
            ds.push_to_hub(HF_DATASET_REPO, token=token)
        return f"Pushed {len(records)} thought(s) to {HF_DATASET_REPO}."

    def _insert_captured_row(
        self,
        tid: str,
        content: str,
        emb: np.ndarray,
        merged: dict[str, Any],
        now: str,
    ) -> None:
        """Insert one row and update HNSW (must run under thread pool; uses :attr:`_lock`)."""
        emb_bytes = emb.tobytes()
        with self._lock:
            conn = self._connect()
            conn.execute(
                """INSERT INTO thoughts (id, content, embedding, metadata, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?)""",
                (tid, content, emb_bytes, json.dumps(merged), now, now),
            )
            conn.commit()
            self._incremental_add_to_index(tid, emb)

    def _push_to_hub_catch(self) -> str:
        """Run :meth:`push_to_hub`, converting any exception into a status string."""
        try:
            return self.push_to_hub()
        except Exception as e:  # best-effort sync: a failed push must not fail the capture
            return f"Hub push failed: {e!r}"

    @staticmethod
    def _parse_json_object(
        raw: str | dict | None, label: str
    ) -> tuple[dict[str, Any], str | None]:
        """Parse an optional JSON-object argument; return ``(object, error_message)``.

        Empty/whitespace input and JSON ``null`` mean ``{}``. A dict is accepted as-is.
        Malformed JSON, or JSON that is not an object (array, number, string), produces
        an error message and an empty dict.
        """
        if isinstance(raw, dict):
            return raw, None
        if not raw or not raw.strip():
            return {}, None
        try:
            value = json.loads(raw)
        except json.JSONDecodeError as e:
            return {}, f"invalid {label} JSON: {e}"
        if value is None:
            return {}, None
        if not isinstance(value, dict):
            return {}, f"{label} JSON must be an object, got {type(value).__name__}"
        return value, None

    @gr.mcp.tool()
    async def capture_thought(
        self, content: str, metadata_json: str = "{}"
    ) -> dict[str, Any]:
        """Save a new thought with automatic embedding and optional structured metadata.

        Embeds ``content`` with the local SentenceTransformers model, stores the row in
        SQLite, updates the HNSW index, and pushes a dataset snapshot to the Hub when
        ``HF_DATASET_REPO`` and a token are configured.
        Embedding and Hub push run in worker threads so the server event loop is not blocked.

        Args:
            content: The free-form thought text to store.
            metadata_json: A JSON **object** as a string (e.g. ``{"topic": "...", "type": "note"}``).
                Merged into stored metadata. Use ``"{}"`` if you have no extra fields;
                arrays or scalars are rejected with an error.
        """
        content = (content or "").strip()
        if not content:
            return {"error": "content is empty"}
        extra, err = self._parse_json_object(metadata_json, "metadata")
        if err is not None:
            return {"error": err}
        emb = await self.embed_async(content)
        tid = str(uuid.uuid4())
        now = _utc_now_iso()
        merged = dict(extra)
        await asyncio.to_thread(
            self._insert_captured_row, tid, content, emb, merged, now
        )
        push_msg = await self.push_to_hub_async()
        return {
            "status": "captured",
            "id": tid,
            "created_at": now,
            "metadata": merged,
            "hub": push_msg,
        }

    def _incremental_add_to_index(self, tid: str, emb: np.ndarray) -> None:
        """Append one vector to the HNSW index or rebuild if the index was empty.

        Called with :attr:`_lock` held (from :meth:`_insert_captured_row`).
        """
        if self._dim is None:
            self._ensure_model()
        assert self._dim is not None
        if self._index is None:
            # The just-committed row is in SQLite, so a rebuild picks it up.
            self.rebuild_hnsw()
            return
        n = len(self._label_to_id)
        cur_max = int(self._index.get_max_elements())
        if n >= cur_max:
            # Grow geometrically to keep appends amortized.
            self._index.resize_index(max(cur_max * 2, n + 1))
        self._index.add_items(emb.reshape(1, -1), np.array([n], dtype=np.int64))
        self._label_to_id.append(tid)

    def _search_with_embedding(
        self,
        q_emb: np.ndarray,
        threshold: float,
        limit: int,
        filt: dict[str, Any],
    ) -> dict[str, Any]:
        """HNSW search and post-filter using a precomputed query embedding.

        Fetches ``limit * HNSW_K_MULTIPLIER`` candidates (at least ``max(limit, 10)``,
        at most the index size). It drops rows that fail the metadata filter or whose
        similarity is ``<= threshold``, then returns the best ``limit`` by similarity.
        A ``limit`` below 1 returns no results.
        """
        limit = int(limit)
        threshold = float(threshold)
        if limit < 1:
            return {"count": 0, "thoughts": []}
        with self._lock:
            conn = self._connect()
            if not self._label_to_id or self._index is None:
                return {"count": 0, "thoughts": []}
            k_fetch = min(
                max(limit * HNSW_K_MULTIPLIER, limit, 10),
                len(self._label_to_id),
            )
            labels, dists = self._index.knn_query(q_emb, k=k_fetch)
            results: list[dict[str, Any]] = []
            for label, d in zip(labels[0].tolist(), dists[0].tolist()):
                tid = self._label_to_id[int(label)]
                row = conn.execute(
                    "SELECT id, content, metadata, created_at FROM thoughts WHERE id = ?",
                    (tid,),
                ).fetchone()
                if row is None:
                    continue
                meta = _parse_metadata(row["metadata"])
                if not json_contains(meta, filt):
                    continue
                # chroma-hnswlib / hnswlib cosine space: distance ~ 1 - cos_sim for L2-normalized vectors
                similarity = float(1.0 - d)
                if similarity <= threshold:
                    continue
                results.append(
                    {
                        "id": row["id"],
                        "content": row["content"],
                        "metadata": meta,
                        "similarity": similarity,
                        "created_at": row["created_at"],
                    }
                )
        results.sort(key=lambda x: x["similarity"], reverse=True)
        results = results[:limit]
        return {"count": len(results), "thoughts": results}

    @gr.mcp.tool()
    async def search_thoughts(
        self,
        query: str,
        threshold: float = 0.4,
        limit: int = 10,
        filter_json: str = "{}",
    ) -> dict[str, Any]:
        """Semantic search over stored thoughts using approximate nearest neighbors (HNSW).

        Embeds ``query`` and retrieves candidate neighbors by cosine distance. It keeps rows
        whose similarity is strictly greater than ``threshold`` and, optionally, whose
        metadata contains a subset object (JSON containment, like PostgreSQL ``@>``). It
        returns up to ``limit`` hits.
        Query embedding runs in a worker thread so the event loop is not blocked during encoding.

        Args:
            query: Natural-language search query; embedded and compared to stored thoughts.
            threshold: Similarity cutoff in ``[0, 1]`` (higher = stricter); results must exceed it. Default 0.4.
            limit: Maximum number of thoughts to return after ranking (1–50 typical).
            filter_json: JSON object as a string; if not ``"{}"``, each result’s metadata must
                **contain** this object (same keys/values; nested objects supported). Example:
                ``{"type": "task"}`` keeps only thoughts whose metadata has ``"type": "task"``.
        """
        query = (query or "").strip()
        if not query:
            return {"error": "query is empty", "thoughts": []}
        filt, err = self._parse_json_object(filter_json, "filter")
        if err is not None:
            return {"error": err, "thoughts": []}
        q_emb = await self.embed_async(query)
        return await asyncio.to_thread(
            self._search_with_embedding, q_emb, threshold, limit, filt
        )

    @gr.mcp.tool()
    def list_thoughts(
        self,
        limit: int = 20,
        offset: int = 0,
        filter_json: str = "{}",
    ) -> dict[str, Any]:
        """List stored thoughts by recency with optional metadata filtering and pagination.

        Thoughts are ordered from newest to oldest by ``created_at``. Filtering uses JSON
        containment on metadata (same rules as ``search_thoughts``).

        Args:
            limit: Maximum rows to return after ``offset`` (e.g. 20). Negative values act as 0.
            offset: Number of matching rows to skip from the start of the newest-first list.
                Negative values act as 0.
            filter_json: JSON object as a string; if not ``"{}"``, only thoughts whose metadata
                contains this object are considered before ``offset``/``limit``.
        """
        filt, err = self._parse_json_object(filter_json, "filter")
        if err is not None:
            return {"error": err, "thoughts": []}
        limit = max(int(limit), 0)
        offset = max(int(offset), 0)
        sql = (
            "SELECT id, content, metadata, created_at, updated_at FROM thoughts"
            " ORDER BY created_at DESC"
        )
        with self._lock:
            conn = self._connect()
            if filt:
                # Metadata filtering happens in Python, so paginate after filtering.
                rows = conn.execute(sql).fetchall()
            else:
                # No filter: let SQLite paginate instead of loading every row.
                rows = conn.execute(sql + " LIMIT ? OFFSET ?", (limit, offset)).fetchall()
        matched: list[dict[str, Any]] = []
        for r in rows:
            if filt and len(matched) >= offset + limit:
                break  # the requested page is already complete
            meta = _parse_metadata(r["metadata"])
            if filt and not json_contains(meta, filt):
                continue
            matched.append(
                {
                    "id": r["id"],
                    "content": r["content"],
                    "metadata": meta,
                    "created_at": r["created_at"],
                    "updated_at": r["updated_at"],
                }
            )
        page = matched[offset : offset + limit] if filt else matched
        return {"count": len(page), "thoughts": page}

    @gr.mcp.tool()
    def thought_stats(self) -> dict[str, Any]:
        """Aggregate statistics over all thoughts in the local store (no parameters).

        Returns:
            A dict with ``total_thoughts``, ``embedded_thoughts``, ``earliest`` and ``latest``
            ISO timestamps, ``unique_topics`` (distinct string ``metadata.topic`` values), and
            ``by_type`` (counts keyed by ``metadata.type``, or ``"unknown"`` when missing).
        """
        with self._lock:
            conn = self._connect()
            rows = conn.execute(
                "SELECT embedding, metadata, created_at FROM thoughts"
            ).fetchall()
        # Empty blobs (legacy Hub loads) are not real embeddings, matching rebuild_hnsw.
        embedded = sum(1 for r in rows if r["embedding"])
        dates = [r["created_at"] for r in rows if r["created_at"]]
        topics: set[str] = set()
        by_type: dict[str, int] = {}
        for r in rows:
            meta = _parse_metadata(r["metadata"])
            t = meta.get("topic")
            if isinstance(t, str) and t:
                topics.add(t)
            typ = meta.get("type")
            key = str(typ) if typ is not None else "unknown"
            by_type[key] = by_type.get(key, 0) + 1
        return {
            "total_thoughts": len(rows),
            "embedded_thoughts": embedded,
            "earliest": min(dates) if dates else None,
            "latest": max(dates) if dates else None,
            "unique_topics": len(topics),
            "by_type": by_type,
        }
# ---------------------------------------------------------------------------
# Gradio
# ---------------------------------------------------------------------------
# Single shared store; loading at import time lets the UI show the load status.
store: ThoughtStore = ThoughtStore()
STARTUP_STATUS: str = store.load_from_hub()
async def ui_capture(content: str, metadata: str) -> str:
    """Gradio handler: capture a thought and return pretty-printed JSON.

    ``ensure_ascii=False`` keeps non-ASCII thought text readable in the JSON viewer
    instead of rendering it as backslash-u escapes.
    """
    result = await store.capture_thought(content, metadata)
    return json.dumps(result, indent=2, ensure_ascii=False)
async def ui_search(query: str, threshold: float, limit: int, filt: str) -> str:
    """Gradio handler: semantic search and return pretty-printed JSON.

    A cleared "Limit" number field may arrive as ``None``; it falls back to the tool's
    default of 10 instead of failing in ``int(None)``. Non-ASCII text is kept readable.
    """
    n = 10 if limit is None else int(limit)
    result = await store.search_thoughts(query, threshold, n, filt)
    return json.dumps(result, indent=2, ensure_ascii=False)
def ui_list(limit: int, offset: int, filt: str) -> str:
    """Gradio handler: list thoughts and return pretty-printed JSON.

    Cleared number fields may arrive as ``None``; they fall back to the tool defaults
    (limit 20, offset 0). Non-ASCII text is kept readable in the JSON output.
    """
    n = 20 if limit is None else int(limit)
    skip = 0 if offset is None else int(offset)
    return json.dumps(store.list_thoughts(n, skip, filt), indent=2, ensure_ascii=False)
def ui_stats() -> str:
    """Gradio handler: aggregate stats as pretty-printed JSON (non-ASCII keys kept readable)."""
    return json.dumps(store.thought_stats(), indent=2, ensure_ascii=False)
with gr.Blocks(title="Open Brain") as demo:
    # Expose the store's tool methods as API endpoints (MCP tools when the MCP server is on).
    gr.api(store.capture_thought)
    gr.api(store.search_thoughts)
    gr.api(store.list_thoughts)
    gr.api(store.thought_stats)
    gr.Markdown(
        "## Open Brain\n"
        "Capture thoughts with **local embeddings** (SentenceTransformers), "
        "**HNSW** search over SQLite, sync to a **Hugging Face Dataset**.\n\n"
        "**MCP:** tools are exposed when the Gradio MCP server is enabled "
        "(see ``GRADIO_MCP_SERVER``; endpoint is under ``/gradio_api/mcp/``)."
    )
    startup = gr.Textbox(
        label="Startup / Hub load status",
        value=STARTUP_STATUS,
        interactive=False,
    )
    # UI-only handlers below use api_visibility="private" so they are not API endpoints.
    with gr.Tab("Capture"):
        cap_content = gr.Textbox(label="Content", lines=6)
        cap_meta = gr.Textbox(
            label="Optional metadata (JSON object)",
            lines=4,
            value="{}",
        )
        cap_btn = gr.Button("Capture thought")
        cap_out = gr.Code(label="Result", language="json")
        cap_btn.click(
            ui_capture, [cap_content, cap_meta], cap_out, api_visibility="private"
        )
    with gr.Tab("Search"):
        s_q = gr.Textbox(label="Query", lines=2)
        s_thr = gr.Slider(0, 1, value=0.4, step=0.05, label="Minimum similarity")
        s_lim = gr.Number(value=10, label="Limit", precision=0)
        s_f = gr.Textbox(label="Optional metadata filter (JSON)", lines=3, value="{}")
        s_btn = gr.Button("Search")
        s_out = gr.Code(label="Results", language="json")
        s_btn.click(
            ui_search, [s_q, s_thr, s_lim, s_f], s_out, api_visibility="private"
        )
    with gr.Tab("List"):
        l_lim = gr.Number(value=20, label="Limit", precision=0)
        l_off = gr.Number(value=0, label="Offset", precision=0)
        l_f = gr.Textbox(label="Optional metadata filter (JSON)", lines=3, value="{}")
        l_btn = gr.Button("List")
        l_out = gr.Code(label="Results", language="json")
        l_btn.click(ui_list, [l_lim, l_off, l_f], l_out, api_visibility="private")
    with gr.Tab("Stats"):
        st_btn = gr.Button("Refresh stats")
        st_out = gr.Code(label="Stats", language="json")
        st_btn.click(ui_stats, None, st_out, api_visibility="private")
demo.queue()
if __name__ == "__main__":
    # Honour the GRADIO_MCP_SERVER switch (documented as the way to disable MCP)
    # instead of unconditionally forcing the MCP server on; default stays enabled.
    _mcp_flag = os.environ.get("GRADIO_MCP_SERVER", "true").strip().lower()
    demo.launch(mcp_server=_mcp_flag in ("1", "true", "yes", "on"))