aifeifei798's picture
Upload 7 files
719390c verified
raw
history blame
6.12 kB
import json
import os
import sqlite3
from contextlib import closing

import google.generativeai as genai
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

from tools.tool_registry import get_all_tools
# --- Configuration for persistence paths ---
# All persistent artifacts live in "<directory of this file>/../data".
DATA_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "data")
)
# SQLite file holding tool metadata (name / description / parameters).
SQLITE_DB_PATH = os.path.join(DATA_DIR, "tools.metadata.db")
# Local file backing the Milvus Lite vector store.
MILVUS_DATA_PATH = os.path.join(DATA_DIR, "milvus_lite.db")

# --- Model and DB Configuration ---
# Dimension of the Milvus vector field; every vector produced by
# EMBEDDING_MODEL_NAME and inserted into the collection must have this length.
EMBEDDING_DIM = 3072
EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
MILVUS_COLLECTION_NAME = "tool_embeddings"
def initialize_system():
    """
    Bring up the tool-retrieval system and return its main components.

    Steps, in order:
      1. make sure DATA_DIR exists;
      2. create the SQLite ``tools`` table if needed and sync the registry's
         tool definitions into it;
      3. rebuild the Milvus collection and embed the tools read back from
         SQLite (so step 2 must run first);
      4. construct a ``DirectToolRecommender`` over both stores.

    Safe to call repeatedly: the directory and table are only created when
    missing, and the Milvus collection is rebuilt from scratch on every call.

    Returns:
        tuple: ``(all_tools_definitions, tool_recommender)`` where the first
        item is whatever ``get_all_tools()`` returned.
    """
    print("--- Starting System Initialization (Final Version) ---")
    os.makedirs(DATA_DIR, exist_ok=True)

    # Stage 1: SQLite is the source of truth that the Milvus sync reads from.
    _init_sqlite_db()
    tools = get_all_tools()
    _sync_tools_to_sqlite(tools)

    # Stage 2: vector store, populated from the rows written above.
    client = _init_milvus_and_sync_embeddings()

    # Stage 3: recommender. Imported here rather than at module top;
    # NOTE(review): presumably to avoid a circular import — confirm.
    from core.tool_recommender import DirectToolRecommender

    recommender = DirectToolRecommender(
        milvus_client=client,
        sqlite_db_path=SQLITE_DB_PATH,
    )

    print("--- System Initialization Complete ---")
    return tools, recommender
def _init_sqlite_db(db_path=None):
    """
    Create the SQLite ``tools`` table if it does not already exist.

    Schema: autoincrement integer ``id``, unique non-null ``name``, and
    non-null ``description`` / ``parameters`` text columns (``parameters``
    holds the JSON-encoded tool args written by the sync step).

    Args:
        db_path: SQLite file to open. Defaults to ``SQLITE_DB_PATH``.
    """
    if db_path is None:
        db_path = SQLITE_DB_PATH
    print(f"SQLite DB Path: {db_path}")
    # A sqlite3 connection used as a context manager only commits/rolls back;
    # it does not close. closing() guarantees the file handle is released.
    with closing(sqlite3.connect(db_path)) as conn:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS tools (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT UNIQUE NOT NULL,
                description TEXT NOT NULL,
                parameters TEXT NOT NULL
            )
            """
        )
        conn.commit()
    print("SQLite DB table verified.")
def _sync_tools_to_sqlite(tools_definitions, db_path=None):
    """
    Upsert tool definitions into the SQLite ``tools`` table.

    Tools whose name is not yet stored are inserted. Tools that already exist
    are updated in place (keeping their ``id``) when their description or
    JSON-encoded parameters differ from what is stored, so the table -- and
    the embeddings later generated from it -- reflect the current definitions
    rather than whatever was recorded on the first run.

    Args:
        tools_definitions: iterable of tool objects exposing ``name``,
            ``description`` and ``args`` (``args`` must be JSON-serializable).
        db_path: SQLite file to write. Defaults to ``SQLITE_DB_PATH``.

    NOTE(review): rows for tools that are no longer in ``tools_definitions``
    are not removed, so they are still embedded into Milvus. Confirm whether
    pruning is wanted.
    """
    if db_path is None:
        db_path = SQLITE_DB_PATH
    print("Syncing tool metadata to SQLite...")
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        for tool in tools_definitions:
            parameters = json.dumps(tool.args)
            cursor.execute(
                "SELECT description, parameters FROM tools WHERE name = ?",
                (tool.name,),
            )
            stored = cursor.fetchone()
            if stored is None:
                cursor.execute(
                    "INSERT INTO tools (name, description, parameters) VALUES (?, ?, ?)",
                    (tool.name, tool.description, parameters),
                )
                print(f" - Added new tool to SQLite: {tool.name}")
            elif stored != (tool.description, parameters):
                # UPDATE (not INSERT OR REPLACE) so the row id stays stable.
                cursor.execute(
                    "UPDATE tools SET description = ?, parameters = ? WHERE name = ?",
                    (tool.description, parameters, tool.name),
                )
                print(f" - Updated tool in SQLite: {tool.name}")
        conn.commit()
    print("SQLite sync complete.")
def _init_milvus_and_sync_embeddings():
    """
    Open the Milvus Lite store, rebuild the tool collection, and populate it.

    Any existing collection named ``MILVUS_COLLECTION_NAME`` is dropped first,
    so every call starts from an empty collection with an INT64 primary key
    ``id`` and a FLOAT_VECTOR field ``embedding`` of size ``EMBEDDING_DIM``,
    indexed with AUTOINDEX / L2. Embeddings are then inserted by
    ``_sync_tool_embeddings_to_milvus`` and the collection is loaded.

    Returns:
        MilvusClient: client connected to ``MILVUS_DATA_PATH`` with the
        collection loaded (possibly empty if the embedding sync bailed out).
    """
    print(f"Milvus Lite Data Path: {MILVUS_DATA_PATH}")
    client = MilvusClient(uri=MILVUS_DATA_PATH)
    name = MILVUS_COLLECTION_NAME

    # Deliberate full rebuild: guarantees the vector dimension matches
    # EMBEDDING_DIM and no stale vectors survive from a previous run.
    if client.has_collection(collection_name=name):
        client.drop_collection(collection_name=name)
        print("Found old Milvus collection. Dropped it to rebuild.")

    print(f"Creating Milvus collection '{name}' with dimension {EMBEDDING_DIM}...")
    schema = CollectionSchema(
        [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(
                name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM
            ),
        ]
    )
    client.create_collection(collection_name=name, schema=schema)

    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="embedding", index_type="AUTOINDEX", metric_type="L2"
    )
    client.create_index(collection_name=name, index_params=index_params)
    print("Milvus collection and index created successfully.")

    # Fill the fresh collection, then load it so it is ready for search.
    _sync_tool_embeddings_to_milvus(client)
    client.load_collection(collection_name=name)
    return client
def _sync_tool_embeddings_to_milvus(milvus_client):
    """
    Embed every tool description stored in SQLite and insert it into Milvus.

    Each vector is stored under the tool's SQLite ``id`` as its Milvus primary
    key, so the two stores share the same identifier for a tool.

    Best-effort on missing inputs: if ``GEMINI_API_KEY`` is unset or the
    ``tools`` table is empty, an error is printed and nothing is inserted.

    Args:
        milvus_client: connected MilvusClient whose ``MILVUS_COLLECTION_NAME``
            collection already exists with an ``EMBEDDING_DIM`` vector field.

    Raises:
        ValueError: if the embedding call returns a different number of
            vectors than descriptions sent, or a vector whose length is not
            ``EMBEDDING_DIM``. Errors from the embedding API or the Milvus
            insert propagate unchanged.
    """
    print("Syncing tool embeddings to Milvus...")
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("Error: GEMINI_API_KEY not found.")
        return
    genai.configure(api_key=api_key)

    # Read everything up front and close the connection before the network
    # call; sqlite3's own context manager would not close it.
    with closing(sqlite3.connect(SQLITE_DB_PATH)) as conn:
        all_tools_in_db = conn.execute("SELECT id, description FROM tools").fetchall()

    if not all_tools_in_db:
        print("Error: No tools found in SQLite to sync.")
        return

    print(f"Found {len(all_tools_in_db)} tools from SQLite, generating embeddings...")
    tool_ids_to_insert = [row[0] for row in all_tools_in_db]
    docs_to_embed = [row[1] for row in all_tools_in_db]

    print(f"Using embedding model: {EMBEDDING_MODEL_NAME}")
    result = genai.embed_content(
        model=EMBEDDING_MODEL_NAME,
        content=docs_to_embed,
        task_type="retrieval_document",
    )
    embeddings = result["embedding"]

    # zip() would silently drop tools on a count mismatch, and a wrong vector
    # length would only surface as an opaque Milvus insert error.
    if len(embeddings) != len(tool_ids_to_insert):
        raise ValueError(
            f"Embedding count mismatch: sent {len(tool_ids_to_insert)} "
            f"descriptions, received {len(embeddings)} vectors."
        )
    for tool_id, embedding in zip(tool_ids_to_insert, embeddings):
        if len(embedding) != EMBEDDING_DIM:
            raise ValueError(
                f"Embedding for tool id {tool_id} has dimension "
                f"{len(embedding)}, expected {EMBEDDING_DIM}."
            )

    data_to_insert = [
        {"id": tool_id, "embedding": embedding}
        for tool_id, embedding in zip(tool_ids_to_insert, embeddings)
    ]
    milvus_client.insert(collection_name=MILVUS_COLLECTION_NAME, data=data_to_insert)
    print(f"Successfully inserted {len(data_to_insert)} new embeddings into Milvus.")