# FeiMatrix-Synapse / core / tool_recommender.py
# Provenance (file-viewer residue): uploaded by aifeifei798, commit 719390c
# ("Upload 7 files", verified), size 3.05 kB.
import sqlite3
import os
from pymilvus import MilvusClient
import google.generativeai as genai
from typing import List, Dict
class DirectToolRecommender:
    """
    Directly uses Milvus and Google GenAI for tool recommendation.
    No dependency on LlamaIndex.

    Workflow of ``recommend_tools``:
      1. embed the user query with the configured GenAI embedding model;
      2. search the ``tool_embeddings`` Milvus collection for the nearest ids;
      3. load the matching rows from the SQLite ``tools`` table and return
         them in the order Milvus returned the hits.
    """

    def __init__(self, milvus_client: MilvusClient, sqlite_db_path: str):
        """
        Store the backends and configure the GenAI client.

        Args:
            milvus_client: Client used to run the vector search.
            sqlite_db_path: Path of the SQLite database holding the ``tools`` table.

        Raises:
            ValueError: If the ``GEMINI_API_KEY`` environment variable is unset or empty.
        """
        self.milvus_client = milvus_client
        self.sqlite_db_path = sqlite_db_path
        self.collection_name = "tool_embeddings"
        self.embedding_model_name = "gemini-embedding-exp-03-07"

        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "Error: GEMINI_API_KEY not found. The recommender cannot function."
            )
        genai.configure(api_key=api_key)
        print(
            f"Direct Tool Recommender initialized, using embedding model: {self.embedding_model_name}."
        )

    def recommend_tools(self, user_query: str, top_k: int = 3) -> List[Dict]:
        """
        Recommends the top_k most relevant tools based on the user query.

        Args:
            user_query: Free-text query to match against the stored tool embeddings.
            top_k: Maximum number of tools to return.

        Returns:
            A list of dicts with the keys ``name``, ``description`` and
            ``parameters``, ordered as Milvus ranked the hits. Ids returned by
            Milvus that have no row in SQLite are skipped. The list is empty
            when the query is blank, ``top_k`` is not positive, or Milvus
            returns no hits.
        """
        print(f"\n[Tool Recommender] Received query: '{user_query}'")

        # Nothing meaningful can be embedded or requested: skip the remote calls.
        if top_k <= 0 or not user_query or not user_query.strip():
            print(
                "[Tool Recommender] Empty query or non-positive top_k; nothing to recommend."
            )
            return []

        # 1. Generate query embedding directly
        query_embedding = self._embed_query(user_query)

        # 2. Search for similar tools in Milvus
        recommended_ids = self._search_tool_ids(query_embedding, top_k)
        if not recommended_ids:
            print("[Tool Recommender] No similar tools found in Milvus.")
            return []
        print(f"[Tool Recommender] Milvus recommended tool IDs: {recommended_ids}")

        # 3. Get full tool metadata from SQLite and sort
        id_to_tool_meta = self._fetch_tool_metadata(recommended_ids)
        # SQL "IN (...)" gives no ordering guarantee, so re-apply the Milvus order.
        sorted_tools = [
            id_to_tool_meta[tool_id]
            for tool_id in recommended_ids
            if tool_id in id_to_tool_meta
        ]
        print(
            f"[Tool Recommender] Final recommended tools: {[t['name'] for t in sorted_tools]}"
        )
        return sorted_tools

    def _embed_query(self, user_query: str):
        """Return the embedding of ``user_query`` from the configured GenAI model."""
        result = genai.embed_content(
            model=self.embedding_model_name,
            content=user_query,
            task_type="retrieval_query",
        )
        return result["embedding"]

    def _search_tool_ids(self, query_embedding, top_k: int) -> List:
        """Return the ids of up to ``top_k`` nearest tools in Milvus, best hit first."""
        search_results = self.milvus_client.search(
            collection_name=self.collection_name,
            data=[query_embedding],
            limit=top_k,
            output_fields=["id"],
        )
        # One query vector was sent, so only the first result list is relevant.
        if not search_results or not search_results[0]:
            return []
        return [hit["id"] for hit in search_results[0]]

    def _fetch_tool_metadata(self, tool_ids: List) -> Dict:
        """Load the ``tools`` rows for ``tool_ids`` from SQLite, keyed by tool id."""
        # Only the "?" placeholders are interpolated; the ids themselves are bound
        # as parameters, so the query stays injection-safe.
        placeholders = ",".join("?" for _ in tool_ids)
        # sqlite3's connection context manager only commits/rolls back and does
        # not close the connection, so close it explicitly to avoid leaking one
        # connection per call.
        conn = sqlite3.connect(self.sqlite_db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                f"SELECT id, name, description, parameters FROM tools WHERE id IN ({placeholders})",
                tool_ids,
            )
            tools_metadata = cursor.fetchall()
        finally:
            conn.close()
        return {
            row[0]: {"name": row[1], "description": row[2], "parameters": row[3]}
            for row in tools_metadata
        }