# mathagent/scripts/build_weaviate_kb.py
# kaushik1064's picture
# Add backend FastAPI code
# 886572e
"""Script to build Weaviate vector store from knowledge base dataset."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import weaviate
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from app.config import settings
def load_dataset(path: Path) -> List[Dict[str, Any]]:
    """Load a JSON Lines knowledge-base file into a list of records.

    Every non-blank line of the file is parsed as one JSON document.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order. Blank or whitespace-only lines
        are skipped.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Iterate the file handle instead of ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on characters such as U+2028, U+2029
    # and U+0085, which may appear unescaped inside a JSON string and would
    # cut a valid record in two. Text-mode iteration only splits on real
    # newlines, and it avoids holding a second full copy of the file.
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
def get_weaviate_client() -> weaviate.Client:
    """Build a Weaviate client that authenticates with an API key.

    The endpoint URL and the API key are read from ``settings.weaviate_url``
    and ``settings.weaviate_api_key`` respectively.

    Returns:
        The newly constructed ``weaviate.Client``.
    """
    api_key_auth = weaviate.auth.AuthApiKey(api_key=settings.weaviate_api_key)
    return weaviate.Client(url=settings.weaviate_url, auth_client_secret=api_key_auth)
def setup_schema(client: weaviate.Client) -> None:
    """Ensure the knowledge-base class is present in the Weaviate schema.

    Does nothing when a class named ``settings.weaviate_class_name`` already
    exists. Otherwise the class is created with three ``text`` properties
    (``question``, ``answer``, ``source``) and the vectorizer set to
    ``"none"``.

    Args:
        client: Weaviate client used to inspect and modify the schema.
    """
    if client.schema.exists(settings.weaviate_class_name):
        return

    text_fields = ("question", "answer", "source")
    schema_definition = {
        "class": settings.weaviate_class_name,
        # No server-side vectorizer: vectors are supplied with each object.
        "vectorizer": "none",
        "properties": [
            {"name": field_name, "dataType": ["text"]} for field_name in text_fields
        ],
    }
    client.schema.create_class(schema_definition)
def build_index(client: weaviate.Client, records: List[Dict[str, Any]], batch_size: int = 100) -> None:
    """Embed every record and queue it for import through the Weaviate batch API.

    For each record the question and answer are joined with a newline,
    encoded with the SentenceTransformer named by
    ``settings.embedding_model_name``, scaled to unit length, and added to the
    client's batch together with its ``question``/``answer``/``source``
    properties.

    Args:
        client: Weaviate client whose ``batch`` context manager receives the
            objects.
        records: Dataset rows. ``question`` is required; ``answer`` and
            ``source`` are optional.
        batch_size: Value assigned to the batch's ``batch_size`` attribute
            (presumably the number of objects sent per request; confirm
            against the installed weaviate-client version).

    Raises:
        KeyError: If a record has no ``question`` key.
    """
    encoder = SentenceTransformer(settings.embedding_model_name)
    # Loop-invariant lookup, hoisted out of the per-record loop.
    class_name = settings.weaviate_class_name

    with client.batch as batch:
        batch.batch_size = batch_size
        for record in tqdm(records, desc="Importing records"):
            question = record["question"]
            # A JSON ``null`` answer arrives as None; ``.get(key, "")`` does
            # not cover that case and the concatenation below would raise
            # TypeError, so treat None like a missing answer.
            answer = record.get("answer")
            if answer is None:
                answer = ""

            # Generate the embedding from question and answer.
            embedding = encoder.encode(question + "\n" + answer)

            # Normalize to unit length. An all-zero embedding has norm 0 and
            # dividing by it would yield a NaN vector, so leave it unscaled.
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm

            properties = {
                "question": question,
                "answer": answer,
                "source": record.get("source", ""),
            }

            # Queue the object together with its precomputed vector.
            batch.add_data_object(
                data_object=properties,
                class_name=class_name,
                vector=embedding.tolist(),
            )
def main(dataset_path: Path) -> None:
    """Run the full import: load the dataset, connect, set up the schema, index.

    A progress message is printed before each stage and a summary with the
    number of loaded records is printed at the end.

    Args:
        dataset_path: Path of the JSONL knowledge-base file to import.
    """
    print(f"Loading dataset from {dataset_path}")
    kb_records = load_dataset(dataset_path)

    print("Initializing Weaviate client")
    weaviate_client = get_weaviate_client()

    print("Setting up schema")
    setup_schema(weaviate_client)

    print("Building index")
    build_index(weaviate_client, kb_records)

    record_count = len(kb_records)
    print(f"Successfully imported {record_count} records to Weaviate")
if __name__ == "__main__":
    # Command-line entry point: the only option is --dataset, the JSONL file
    # to import (the default path is relative to the working directory).
    cli = argparse.ArgumentParser(description="Build Weaviate vector store for math agent")
    cli.add_argument(
        "--dataset",
        type=Path,
        default=Path("backend/data/knowledge_base.jsonl"),
    )
    main(cli.parse_args().dataset)