# Source: BactKing / rag/rag_embedder.py
# Last change: "Update rag/rag_embedder.py" by EphAsad (commit d0e80d5, verified)
# rag/rag_embedder.py
# ============================================================
# Embedding utilities for RAG (knowledge base + queries)
# Uses a SentenceTransformer model for dense embeddings.
# ============================================================
from __future__ import annotations
import os
import json
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
# ------------------------------------------------------------
# CONFIG
# ------------------------------------------------------------
# Identifier of the SentenceTransformer model used for both KB documents and
# queries. load_kb_index() refuses an index whose "model_name" differs.
EMBEDDING_MODEL_NAME: str = "sentence-transformers/all-mpnet-base-v2"

# Lazily created model instance; stays None until get_embedder() first runs.
_model: SentenceTransformer | None = None
# ------------------------------------------------------------
# MODEL LOADING
# ------------------------------------------------------------
def get_embedder() -> SentenceTransformer:
    """
    Return the module-wide SentenceTransformer, loading it on first call.

    The instance is cached in the module global ``_model``, so every later
    call returns the same object without reloading the model.
    """
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return _model
# ------------------------------------------------------------
# EMBEDDING
# ------------------------------------------------------------
def embed_text(text: str, normalize: bool = True) -> np.ndarray:
    """
    Embed one string and return its vector.

    The string is encoded as a one-element batch with the shared model, and
    the single resulting row is returned as a 1D numpy array
    (MPNet: 768-dim).

    Args:
        text: The string to embed.
        normalize: Passed through to ``encode`` as ``normalize_embeddings``.

    Returns:
        The embedding row for ``text``.
    """
    encoder = get_embedder()
    batch = encoder.encode(
        [text],
        normalize_embeddings=normalize,
        show_progress_bar=False,
    )
    # Exactly one input was given, so the batch has exactly one row.
    (vector,) = batch
    return vector
def embed_texts(texts: List[str], normalize: bool = True) -> np.ndarray:
    """
    Embed a list of strings -> (N, D) numpy array.

    Args:
        texts: Strings to embed, one output row per string.
        normalize: Passed through to ``SentenceTransformer.encode`` as
            ``normalize_embeddings``.

    Returns:
        A 2-D array with one row per input string. For an empty ``texts``
        this is an empty ``(0, D)`` float32 array (when the model reports
        its embedding dimension), so callers can still stack or matmul it.
    """
    model = get_embedder()
    # len() rather than truthiness so a numpy array of strings also works.
    if len(texts) == 0:
        # encode([]) ends in np.asarray([]), i.e. a 1-D float64 array of
        # shape (0,), which violates the (N, D) contract above.
        # NOTE(review): re-check if sentence-transformers is upgraded.
        dim = model.get_sentence_embedding_dimension()
        if dim is not None:
            return np.empty((0, dim), dtype=np.float32)
    return model.encode(
        texts,
        show_progress_bar=False,
        normalize_embeddings=normalize,
    )
# ------------------------------------------------------------
# INDEX LOADING
# ------------------------------------------------------------
def load_kb_index(path: str = "data/rag/index/kb_index.json") -> Dict[str, Any]:
    """
    Load the RAG knowledge base index JSON.

    Expected format:
    {
        "version": int,
        "model_name": str,
        "records": [
            {
                "id": str,
                "genus": str,
                "species": str | null,
                "level": "genus" | "species",
                "chunk_id": int,
                "source_file": str,
                "text": str,
                "embedding": [float, ...]
            }
        ]
    }

    Every record whose ``embedding`` is a JSON list has it replaced in place
    by a float32 numpy array. Other values are left untouched. A missing or
    null ``records`` entry is treated as an empty index.

    Args:
        path: Location of the index file.

    Returns:
        The parsed index dict, with embeddings converted as described.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the top-level JSON value is not an object, or its
            ``model_name`` differs from EMBEDDING_MODEL_NAME. Malformed JSON
            raises json.JSONDecodeError, a ValueError subclass.
    """
    # EAFP: open directly instead of exists()-then-open(), which could race.
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"KB index not found at {path}") from err

    # Guard against e.g. a bare JSON list, which would otherwise fail below
    # with an opaque AttributeError on .get().
    if not isinstance(data, dict):
        raise ValueError(
            f"KB index at {path} must be a JSON object, "
            f"got {type(data).__name__}."
        )

    index_model = data.get("model_name")
    if index_model != EMBEDDING_MODEL_NAME:
        raise ValueError(
            f"KB index built with '{index_model}', "
            f"but current embedder is '{EMBEDDING_MODEL_NAME}'. "
            "Rebuild the index."
        )

    # Convert embeddings to numpy arrays. "records": null would make a plain
    # .get("records", []) yield None and crash the loop, so fall back to [].
    for rec in data.get("records") or []:
        if isinstance(rec.get("embedding"), list):
            rec["embedding"] = np.array(rec["embedding"], dtype="float32")
    return data