# Financial_bot / src/embedder.py
# (Hugging Face upload by Pushkya: "Upload 30 files", commit 8299003, 17.6 kB)
"""
embedder.py
===========
Phase 4 – Embedding & Vector Store Ingestion
Converts chunks produced by chunker.py into vector embeddings and persists
them in a ChromaDB collection (or a Qdrant collection when QDRANT_URL is set).
Model: sentence-transformers/all-MiniLM-L6-v2
- 384-dimensional embeddings
- 256-token context window (matches our chunking max_tokens exactly)
- Fast inference, strong retrieval quality for English financial text
- Same model for indexing AND retrieval → vectors are in the same space
ChromaDB collection: financial_docs
- Single collection for ALL document types (Morningstar + SEC filings)
- Cosine similarity space (best for sentence transformers)
- Upsert semantics → re-running is fully safe, no duplicate vectors
- Metadata filters enable per-source / per-doc-type / per-ticker queries:
{"source": "morningstar"}
{"doc_type": "10-K", "ticker": "AAPL"}
{"chunk_type": "table"}
ChromaDB metadata constraints
ChromaDB only accepts scalar values (str, int, float, bool).
Lists (e.g., col_headers) are JSON-serialised to strings.
None values are replaced with "" to avoid insertion errors.
Output format per stored vector
id : chunk_id (e.g. "ptc01302411420_text_0042")
document : chunk text (prose or markdown table)
metadata : all chunk metadata + chunk_type, is_atomic, doc_id
Usage (as a module)
-------------------
from src.embedder import DocumentEmbedder
emb = DocumentEmbedder()
emb.embed_document("data/chunks/morningstar/ptc01302411420_chunks.json")
results = emb.query("What is PTC's revenue growth?", n_results=5)
Usage (as a script)
-------------------
python src/embedder.py
python src/embedder.py --force # re-embed even if already stored
"""
import os
import json
import hashlib
import logging
from pathlib import Path
from datetime import datetime, timezone
from dotenv import load_dotenv
load_dotenv()
# ── Backend selection ──────────────────────────────────────────────────────────
# The local ChromaDB store is the default; defining QDRANT_URL in the
# environment (or in the .env file loaded above) switches to Qdrant.
_BACKEND = "chromadb" if not os.getenv("QDRANT_URL") else "qdrant"
# ── Logging ────────────────────────────────────────────────────────────────────
# Note: configures the root logger as soon as this module is imported.
logging.basicConfig(
    format = "%(asctime)s %(levelname)-8s %(message)s",
    level  = logging.INFO,
)
log = logging.getLogger(__name__)
# ── Paths ──────────────────────────────────────────────────────────────────────
# Project root = the directory two levels above this file.
BASE_DIR        = Path(__file__).parent.parent
CHUNKS_DIR      = BASE_DIR.joinpath("data", "chunks")
VECTORSTORE_DIR = BASE_DIR.joinpath("data", "vectorstore")
# ── Constants ──────────────────────────────────────────────────────────────────
COLLECTION_NAME = "financial_docs"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Chunks sent per upsert request (used by both the ChromaDB and Qdrant paths).
BATCH_SIZE      = 100
# ══════════════════════════════════════════════════════════════════════════════
# METADATA SANITISATION
# ──────────────────────────────────────────────────────────────────────────────
# ChromaDB only accepts scalar metadata values (str, int, float, bool), so
# everything else is converted before insertion:
#   - None                → ""
#   - list / tuple / dict → JSON string (round-trippable with json.loads)
#   - other               → str() as a safety fallback
# ══════════════════════════════════════════════════════════════════════════════
def _sanitize_metadata(meta: dict) -> dict:
    """
    Convert a metadata dict to ChromaDB-compatible scalar values.

    ChromaDB accepted types: str | int | float | bool

    Conversion rules:
      - str / int / float / bool → kept as-is
      - None                     → ""
      - list / tuple / dict      → JSON string (ensure_ascii=False); if the
                                   value is not JSON-serialisable, str(value)
      - anything else            → str(value)

    Args:
        meta : flat metadata dict of one chunk
    Returns:
        a new dict with the same keys and scalar values (``meta`` is not modified)
    """
    clean = {}
    for key, value in meta.items():
        if isinstance(value, (str, int, float, bool)):
            clean[key] = value
        elif value is None:
            clean[key] = ""
        elif isinstance(value, (list, tuple, dict)):
            # In this module chunk metadata comes from json.load, so nested
            # dicts can occur. JSON-encode them like lists: str(dict) would give
            # a Python repr (single quotes) that json.loads cannot parse back.
            try:
                clean[key] = json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                # e.g. contains a datetime or a circular reference
                clean[key] = str(value)
        else:
            clean[key] = str(value)
    return clean
# ══════════════════════════════════════════════════════════════════════════════
# MAIN EMBEDDER CLASS
# ══════════════════════════════════════════════════════════════════════════════
class DocumentEmbedder:
    """
    Embeds document chunks and persists them in a vector store.

    The backend comes from the module-level ``_BACKEND`` switch: a local
    ChromaDB ``PersistentClient`` by default, or a remote Qdrant instance when
    the ``QDRANT_URL`` environment variable is set. Both backends store
    cosine-distance vectors from the same SentenceTransformer model, and every
    public method (embed_document, embed_all, query, get_stats) works with
    either backend.

    All documents (Morningstar PDFs, SEC filings) share one collection so
    cross-document similarity search works out of the box. Use metadata
    filters to restrict retrieval to a specific source or document type.
    """

    # Page size used when scrolling through a Qdrant collection in get_stats.
    _QDRANT_SCROLL_PAGE = 1000

    def __init__(
        self,
        vectorstore_dir : Path = VECTORSTORE_DIR,
        collection_name : str  = COLLECTION_NAME,
        embedding_model : str  = EMBEDDING_MODEL,
    ):
        """
        Args:
            vectorstore_dir : on-disk directory for ChromaDB (unused by Qdrant)
            collection_name : collection to create or reuse
            embedding_model : SentenceTransformer model name used for embedding
        """
        self._collection_name = collection_name
        self._embedding_model = embedding_model
        self._backend         = _BACKEND
        if self._backend == "qdrant":
            self._init_qdrant(collection_name, embedding_model)
        else:
            self._init_chromadb(vectorstore_dir, collection_name, embedding_model)

    def _init_chromadb(self, vectorstore_dir, collection_name, embedding_model):
        """Open (or create) the persistent ChromaDB collection under ``vectorstore_dir``."""
        import chromadb
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
        self.vectorstore_dir = Path(vectorstore_dir)
        self.vectorstore_dir.mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=str(self.vectorstore_dir))
        # Chroma embeds both upserted documents and query texts through this function.
        self.ef = SentenceTransformerEmbeddingFunction(model_name=embedding_model)
        self.collection = self.client.get_or_create_collection(
            name               = collection_name,
            embedding_function = self.ef,
            metadata           = {"hnsw:space": "cosine"},
        )
        log.info(
            f"[ChromaDB] collection '{collection_name}' ready "
            f"({self.collection.count()} vectors already stored)"
        )

    def _init_qdrant(self, collection_name, embedding_model):
        """Connect to Qdrant (QDRANT_URL / QDRANT_API_KEY) and ensure the collection exists."""
        from qdrant_client import QdrantClient
        from qdrant_client.models import VectorParams, Distance
        from sentence_transformers import SentenceTransformer
        self._qdrant = QdrantClient(
            url     = os.getenv("QDRANT_URL"),
            api_key = os.getenv("QDRANT_API_KEY"),
        )
        self._embed_model = SentenceTransformer(embedding_model)
        # Create collection if it doesn't exist
        existing = [c.name for c in self._qdrant.get_collections().collections]
        if collection_name not in existing:
            # Size vectors from the model instead of hard-coding 384, so a
            # non-default embedding_model gets a matching collection. If the
            # model does not report its dimension, measure one encoding.
            dim = (
                self._embed_model.get_sentence_embedding_dimension()
                or len(self._embed_model.encode("dimension probe"))
            )
            self._qdrant.create_collection(
                collection_name = collection_name,
                vectors_config  = VectorParams(size=dim, distance=Distance.COSINE),
            )
            log.info(f"[Qdrant] Created collection '{collection_name}'")
        count = self._qdrant.count(collection_name).count
        log.info(
            f"[Qdrant] collection '{collection_name}' ready "
            f"({count} vectors already stored)"
        )

    @staticmethod
    def _point_id(chunk_id: str) -> int:
        """Deterministic Qdrant point ID (non-negative int below 2**63) for a chunk_id."""
        return int(hashlib.md5(chunk_id.encode()).hexdigest(), 16) % (2 ** 63)

    def _existing_ids(self, ids: list[str]) -> set[str]:
        """Return the subset of chunk ids in ``ids`` that are already stored."""
        if not ids:
            return set()
        if self._backend == "qdrant":
            # Map point IDs back to chunk_ids without fetching payloads.
            by_point = {self._point_id(cid): cid for cid in ids}
            records = self._qdrant.retrieve(
                collection_name = self._collection_name,
                ids             = list(by_point),
                with_payload    = False,
                with_vectors    = False,
            )
            return {by_point[r.id] for r in records if r.id in by_point}
        return set(self.collection.get(ids=ids, include=[])["ids"])

    # ── Embed one document ─────────────────────────────────────────────────────
    def embed_document(
        self,
        chunks_json_path : str | Path,
        force            : bool = False,
    ) -> int:
        """
        Embed all chunks from one document and upsert them into the vector store.

        Args:
            chunks_json_path : path to the _chunks.json file (output of chunker.py)
            force            : if True, re-embed even if already stored
        Returns:
            number of chunks upserted (0 if skipped)
        """
        chunks_json_path = Path(chunks_json_path)
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(chunks_json_path, encoding="utf-8") as f:
            doc = json.load(f)
        chunks  = doc["chunks"]
        all_ids = [c["chunk_id"] for c in chunks]
        if not chunks:
            log.info(f"SKIP {chunks_json_path.name} (no chunks)")
            return 0
        # ── Check which chunks are already stored ─────────────────────────────
        if not force:
            already  = self._existing_ids(all_ids)
            to_embed = [c for c in chunks if c["chunk_id"] not in already]
            if not to_embed:
                log.info(
                    f"SKIP {chunks_json_path.name} "
                    f"(all {len(all_ids)} chunks already in store)"
                )
                return 0
            log.info(
                f"Embedding: {chunks_json_path.name} "
                f"({len(to_embed)} new, {len(already)} already stored)"
            )
        else:
            to_embed = chunks
            log.info(f"Embedding: {chunks_json_path.name} ({len(to_embed)} chunks)")
        # ── Build inputs ──────────────────────────────────────────────────────
        ids, documents, metadatas = [], [], []
        for chunk in to_embed:
            ids.append(chunk["chunk_id"])
            documents.append(chunk["text"])
            meta = _sanitize_metadata(chunk["metadata"])
            meta["chunk_type"] = chunk["chunk_type"]
            meta["is_atomic"]  = chunk["is_atomic"]
            meta["doc_id"]     = chunk["doc_id"]
            metadatas.append(meta)
        # ── Upsert in batches ─────────────────────────────────────────────────
        if self._backend == "qdrant":
            self._upsert_qdrant(ids, documents, metadatas)
        else:
            self._upsert_chromadb(ids, documents, metadatas)
        return len(ids)

    def _upsert_chromadb(self, ids, documents, metadatas):
        """Upsert in BATCH_SIZE slices; Chroma computes the embeddings itself."""
        for i in range(0, len(ids), BATCH_SIZE):
            batch_ids  = ids[i : i + BATCH_SIZE]
            batch_docs = documents[i : i + BATCH_SIZE]
            batch_meta = metadatas[i : i + BATCH_SIZE]
            self.collection.upsert(
                ids       = batch_ids,
                documents = batch_docs,
                metadatas = batch_meta,
            )
            log.info(
                f" Batch {i // BATCH_SIZE + 1} "
                f"({len(batch_ids)} chunks) "
                f"total in store: {self.collection.count()}"
            )

    def _upsert_qdrant(self, ids, documents, metadatas):
        """Encode all documents locally, then upsert points in BATCH_SIZE slices."""
        from qdrant_client.models import PointStruct
        log.info(f" Encoding {len(documents)} chunks with SentenceTransformer...")
        embeddings = self._embed_model.encode(
            documents,
            batch_size        = 32,
            show_progress_bar = False,
        )
        for i in range(0, len(ids), BATCH_SIZE):
            batch_ids  = ids[i : i + BATCH_SIZE]
            batch_docs = documents[i : i + BATCH_SIZE]
            batch_meta = metadatas[i : i + BATCH_SIZE]
            batch_embs = embeddings[i : i + BATCH_SIZE]
            points = [
                PointStruct(
                    # deterministic integer ID from chunk_id string
                    id      = self._point_id(cid),
                    vector  = emb.tolist(),
                    # Reserved keys go last so a metadata field that happens to be
                    # named "text" or "chunk_id" cannot overwrite them.
                    payload = {**meta, "text": doc, "chunk_id": cid},
                )
                for cid, doc, meta, emb in zip(batch_ids, batch_docs, batch_meta, batch_embs)
            ]
            self._qdrant.upsert(
                collection_name = self._collection_name,
                points          = points,
            )
            total = self._qdrant.count(self._collection_name).count
            log.info(
                f" Batch {i // BATCH_SIZE + 1} "
                f"({len(points)} chunks) "
                f"total in store: {total}"
            )

    # ── Batch embed all documents ──────────────────────────────────────────────
    def embed_all(
        self,
        chunks_dir : Path = CHUNKS_DIR,
        force      : bool = False,
    ) -> dict:
        """
        Embed all *_chunks.json files found (recursively) under chunks_dir.

        A failure in one file is logged with its traceback and does not stop
        the remaining files.

        Returns:
            dict mapping filename → number of chunks upserted (-1 on failure)
        """
        chunk_files = sorted(Path(chunks_dir).rglob("*_chunks.json"))
        log.info(f"Found {len(chunk_files)} chunk files under {chunks_dir}")
        summary = {}
        for cf in chunk_files:
            try:
                summary[cf.name] = self.embed_document(cf, force=force)
            except Exception as e:
                # Deliberate best-effort boundary: keep going, but keep the traceback.
                log.exception(f" FAILED {cf.name}: {e}")
                summary[cf.name] = -1
        return summary

    # ── Query ──────────────────────────────────────────────────────────────────
    def query(
        self,
        text      : str,
        n_results : int = 5,
        filters   : dict | None = None,
    ) -> list[dict]:
        """
        Similarity search against the collection.

        Args:
            text      : natural language query (embedded on-the-fly)
            n_results : number of top results to return
            filters   : ChromaDB-style where clause, e.g. {"source": "morningstar"}
                        or {"$and": [{"doc_type": "10-K"}, {"ticker": "AAPL"}]}.
                        On the Qdrant backend it is translated to a Qdrant Filter.
        Returns:
            list of dicts with keys: id, text, metadata, distance
            (distance is cosine distance: 0 = identical, 2 = opposite)
        """
        if self._backend == "qdrant":
            return self._query_qdrant(text, n_results, filters)
        kwargs = {
            "query_texts" : [text],
            "n_results"   : n_results,
            "include"     : ["documents", "metadatas", "distances"],
        }
        if filters:
            kwargs["where"] = filters
        result = self.collection.query(**kwargs)
        return [
            {"id": cid, "text": doc, "metadata": meta, "distance": dist}
            for cid, doc, meta, dist in zip(
                result["ids"][0],
                result["documents"][0],
                result["metadatas"][0],
                result["distances"][0],
            )
        ]

    def _query_qdrant(self, text, n_results, filters):
        """Qdrant implementation of query(); returns the same result shape."""
        vector = self._embed_model.encode(text).tolist()
        response = self._qdrant.query_points(
            collection_name = self._collection_name,
            query           = vector,
            limit           = n_results,
            query_filter    = self._to_qdrant_filter(filters) if filters else None,
            with_payload    = True,
        )
        results = []
        for point in response.points:
            payload  = dict(point.payload or {})
            doc_text = payload.pop("text", "")
            chunk_id = payload.pop("chunk_id", str(point.id))
            results.append({
                "id"       : chunk_id,
                "text"     : doc_text,
                "metadata" : payload,
                # Qdrant's COSINE score is a similarity; convert to cosine
                # distance so both backends report the same quantity.
                "distance" : 1.0 - point.score,
            })
        return results

    @classmethod
    def _to_qdrant_filter(cls, where: dict):
        """
        Translate a ChromaDB-style ``where`` clause into a Qdrant ``Filter``.

        Supported forms:
            {"key": value}                                  equality
            {"key": {"$eq" | "$ne": value}}
            {"key": {"$in" | "$nin": [values]}}
            {"key": {"$gt" | "$gte" | "$lt" | "$lte": number}}
            {"$and": [clause, ...]} / {"$or": [clause, ...]}
        Several keys in one dict are AND-ed together.

        Raises:
            ValueError : for an unsupported operator or a malformed $and/$or
        """
        from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue, Range

        must, should, must_not = [], [], []
        for key, cond in where.items():
            if key in ("$and", "$or"):
                if not isinstance(cond, list):
                    raise ValueError(f"{key} expects a list of clauses, got {cond!r}")
                target = must if key == "$and" else should
                target.extend(cls._to_qdrant_filter(sub) for sub in cond)
            elif isinstance(cond, dict):
                for op, val in cond.items():
                    if op == "$eq":
                        must.append(FieldCondition(key=key, match=MatchValue(value=val)))
                    elif op == "$ne":
                        must_not.append(FieldCondition(key=key, match=MatchValue(value=val)))
                    elif op == "$in":
                        must.append(FieldCondition(key=key, match=MatchAny(any=list(val))))
                    elif op == "$nin":
                        must_not.append(FieldCondition(key=key, match=MatchAny(any=list(val))))
                    elif op in ("$gt", "$gte", "$lt", "$lte"):
                        must.append(FieldCondition(key=key, range=Range(**{op[1:]: val})))
                    else:
                        raise ValueError(f"Unsupported filter operator {op!r} for key {key!r}")
            else:
                must.append(FieldCondition(key=key, match=MatchValue(value=cond)))
        return Filter(must=must or None, should=should or None, must_not=must_not or None)

    # ── Collection stats ───────────────────────────────────────────────────────
    def get_stats(self) -> dict:
        """
        Return a summary of what is stored in the collection.

        Returns:
            dict with total_vectors (int) and sorted lists of the distinct
            ``source``, ``doc_id`` and ``chunk_type`` metadata values.
        """
        empty = {"total_vectors": 0, "sources": [], "doc_ids": [], "chunk_types": []}
        if self._backend == "qdrant":
            count = self._qdrant.count(self._collection_name).count
            if count == 0:
                return empty
            # Page through the collection rather than requesting all `count`
            # points in a single call.
            metas, offset = [], None
            while True:
                records, offset = self._qdrant.scroll(
                    collection_name = self._collection_name,
                    limit           = self._QDRANT_SCROLL_PAGE,
                    offset          = offset,
                    with_payload    = True,
                    with_vectors    = False,
                )
                metas.extend(r.payload or {} for r in records)
                if offset is None:
                    break
        else:
            count = self.collection.count()
            if count == 0:
                return empty
            raw = self.collection.get(limit=count, include=["metadatas"])["metadatas"]
            # Defensive: treat a missing (None) metadata entry as empty.
            metas = [m or {} for m in raw]
        return {
            "total_vectors" : count,
            "sources"       : sorted({m.get("source", "") for m in metas}),
            "doc_ids"       : sorted({m.get("doc_id", "") for m in metas}),
            "chunk_types"   : sorted({m.get("chunk_type", "") for m in metas}),
        }
# ── Entry point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description  = "Embed all *_chunks.json files under data/chunks into the vector store.",
        allow_abbrev = False,
    )
    parser.add_argument(
        "--force",
        action = "store_true",
        help   = "re-embed chunks even if they are already stored",
    )
    # parse_known_args: unrecognised arguments are ignored, not treated as errors.
    args, _ = parser.parse_known_args()
    force = args.force

    # Report the model and the backend actually selected by _BACKEND.
    backend_label = "Qdrant" if _BACKEND == "qdrant" else "ChromaDB"
    model_label   = EMBEDDING_MODEL.rsplit("/", 1)[-1]
    log.info("=" * 60)
    log.info(f"Phase 4 – Document Embedder ({model_label} + {backend_label})")
    log.info("=" * 60)
    embedder = DocumentEmbedder()
    summary  = embedder.embed_all(force=force)
    log.info("\n" + "=" * 60)
    log.info("Embedding complete.")
    for fname, n in summary.items():
        status = f"{n:>5} upserted" if n >= 0 else " FAILED"
        log.info(f" {fname:55s} {status}")
    # Failed files are recorded as -1 and must not reduce the total.
    total = sum(n for n in summary.values() if n > 0)
    log.info(f" {'TOTAL NEW VECTORS':55s} {total:>5}")
    stats = embedder.get_stats()
    log.info(f"\nCollection total: {stats['total_vectors']} vectors")
    log.info(f" Sources : {stats['sources']}")
    log.info(f" Documents : {stats['doc_ids']}")
    log.info(f" Chunk types : {stats['chunk_types']}")
    log.info("=" * 60)