| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | import json |
| | from typing import List, Dict, Any |
| |
|
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Model identifier handed to SentenceTransformer(); load_kb_index() also
# compares it against the "model_name" stored in the KB index.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Module-level cache for the encoder; stays None until get_embedder() fills it.
_model: SentenceTransformer | None = None
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_embedder() -> SentenceTransformer:
    """
    Return the shared SentenceTransformer instance.

    The model named by EMBEDDING_MODEL_NAME is constructed on the first call
    and stored in the module-level ``_model``; later calls return that same
    object.
    """
    global _model
    # Fast path: the encoder has already been built.
    if _model is not None:
        return _model
    _model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return _model
| |
|
| |
|
| | |
| | |
| | |
| |
|
def embed_text(text: str, normalize: bool = True) -> np.ndarray:
    """
    Embed one string and return its vector.

    The text is encoded as a single-item batch and the first (only) row of
    the result is returned. ``normalize`` is forwarded to the encoder as
    ``normalize_embeddings``.
    Returns a 1D numpy array (MPNet: 768-dim).
    """
    encoder = get_embedder()
    # Encode as a batch of one so the encoder sees a list input.
    batch = encoder.encode(
        [text],
        normalize_embeddings=normalize,
        show_progress_bar=False,
    )
    first_row = batch[0]
    return first_row
| |
|
| |
|
def embed_texts(texts: List[str], normalize: bool = True) -> np.ndarray:
    """
    Embed several strings in one encoder call -> (N, D) numpy array.

    ``normalize`` is forwarded to the encoder as ``normalize_embeddings``;
    the progress bar is always disabled.
    """
    encoder = get_embedder()
    vectors = encoder.encode(
        texts,
        normalize_embeddings=normalize,
        show_progress_bar=False,
    )
    return vectors
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_kb_index(path: str = "data/rag/index/kb_index.json") -> Dict[str, Any]:
    """
    Load the RAG knowledge base index JSON.

    Expected format:
    {
        "version": int,
        "model_name": str,
        "records": [
            {
                "id": str,
                "genus": str,
                "species": str | null,
                "level": "genus" | "species",
                "chunk_id": int,
                "source_file": str,
                "text": str,
                "embedding": [float, ...]
            }
        ]
    }

    Every record whose "embedding" is a JSON list has it replaced, in the
    returned dict, by a float32 numpy array.

    Args:
        path: Location of the index JSON file.

    Returns:
        The parsed index dict, with list embeddings converted to numpy arrays.

    Raises:
        FileNotFoundError: If no file exists at ``path``.
        ValueError: If the top-level JSON value is not an object, or if the
            index's "model_name" differs from EMBEDDING_MODEL_NAME.
    """
    # Open directly rather than checking os.path.exists() first: the file
    # could vanish between the check and the open, and a single attempt
    # reports the same error either way.
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"KB index not found at {path}") from err

    # Guard before calling .get(): a top-level list/number/string would
    # otherwise fail with an unhelpful AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"KB index at {path} must contain a JSON object at the top level, "
            f"got {type(data).__name__}."
        )

    # Embeddings from a different model live in a different vector space, so
    # refuse to use an index built with another encoder.
    index_model = data.get("model_name")
    if index_model != EMBEDDING_MODEL_NAME:
        raise ValueError(
            f"KB index built with '{index_model}', "
            f"but current embedder is '{EMBEDDING_MODEL_NAME}'. "
            "Rebuild the index."
        )

    # Convert stored embeddings from JSON lists to float32 arrays in place.
    for rec in data.get("records", []):
        if isinstance(rec.get("embedding"), list):
            rec["embedding"] = np.array(rec["embedding"], dtype="float32")

    return data