Spaces:
Sleeping
Sleeping
| import sqlite3 | |
| import os | |
| from pymilvus import MilvusClient | |
| import google.generativeai as genai | |
| from typing import List, Dict | |
class DirectToolRecommender:
    """
    Recommends tools by combining Milvus vector search with SQLite metadata.

    Directly uses Milvus and Google GenAI for tool recommendation.
    No dependency on LlamaIndex.
    """

    def __init__(self, milvus_client: "MilvusClient", sqlite_db_path: str):
        """
        Store the search/metadata backends and configure Google GenAI.

        Args:
            milvus_client: Client used to search the ``tool_embeddings`` collection.
            sqlite_db_path: Path of the SQLite database containing the ``tools`` table.

        Raises:
            ValueError: If the ``GEMINI_API_KEY`` environment variable is unset or empty.
        """
        self.milvus_client = milvus_client
        self.sqlite_db_path = sqlite_db_path
        self.collection_name = "tool_embeddings"
        self.embedding_model_name = "gemini-embedding-exp-03-07"

        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise ValueError(
                "Error: GEMINI_API_KEY not found. The recommender cannot function."
            )
        # genai.configure is a module-level call, so the key is not scoped to
        # this particular instance.
        genai.configure(api_key=api_key)
        print(
            f"Direct Tool Recommender initialized, using embedding model: {self.embedding_model_name}."
        )

    def _embed_query(self, user_query: str) -> List[float]:
        """Return the embedding vector Google GenAI produces for *user_query*."""
        result = genai.embed_content(
            model=self.embedding_model_name,
            content=user_query,
            task_type="retrieval_query",
        )
        return result["embedding"]

    def _fetch_tool_metadata(self, tool_ids: List) -> Dict:
        """Map each id in *tool_ids* that exists in the ``tools`` table to its metadata."""
        placeholders = ",".join("?" for _ in tool_ids)
        # sqlite3's connection context manager only commits/rolls back; it does
        # not close the connection. Close explicitly so repeated calls do not
        # leak connections.
        conn = sqlite3.connect(self.sqlite_db_path)
        try:
            # Only "?" markers are interpolated into the SQL text; the ids
            # themselves are passed as bound parameters.
            rows = conn.execute(
                f"SELECT id, name, description, parameters FROM tools WHERE id IN ({placeholders})",
                tool_ids,
            ).fetchall()
        finally:
            conn.close()
        return {
            row[0]: {"name": row[1], "description": row[2], "parameters": row[3]}
            for row in rows
        }

    def recommend_tools(self, user_query: str, top_k: int = 3) -> List[Dict]:
        """
        Recommends the top_k most relevant tools based on the user query.

        The query is embedded, the nearest entries are looked up in Milvus, and
        the matching rows are loaded from SQLite.

        Args:
            user_query: Free-text description of what the user wants to do.
            top_k: Maximum number of tools to return.

        Returns:
            A list of dicts with the keys ``name``, ``description`` and
            ``parameters``, in the order Milvus returned the hits. Hits whose id
            has no row in the ``tools`` table are skipped. The list is empty
            when the query is blank, ``top_k`` is not positive, or Milvus
            returns no hits.
        """
        print(f"\n[Tool Recommender] Received query: '{user_query}'")

        # Nothing useful can come back for a blank query or a non-positive
        # limit, so skip the remote embedding and search calls entirely.
        if not user_query or not user_query.strip() or top_k <= 0:
            print(
                "[Tool Recommender] Empty query or non-positive top_k; nothing to recommend."
            )
            return []

        # 1. Generate query embedding directly
        query_embedding = self._embed_query(user_query)

        # 2. Search for similar tools in Milvus
        search_results = self.milvus_client.search(
            collection_name=self.collection_name,
            data=[query_embedding],
            limit=top_k,
            output_fields=["id"],
        )
        # One query vector was sent, so only the first result list is relevant.
        if not search_results or not search_results[0]:
            print("[Tool Recommender] No similar tools found in Milvus.")
            return []

        recommended_ids = [hit["id"] for hit in search_results[0]]
        print(f"[Tool Recommender] Milvus recommended tool IDs: {recommended_ids}")

        # 3. Get full tool metadata from SQLite and sort
        id_to_tool_meta = self._fetch_tool_metadata(recommended_ids)

        # The SQL IN query does not preserve ranking; rebuild the list in the
        # order Milvus returned the ids.
        sorted_tools = [
            id_to_tool_meta[tool_id]
            for tool_id in recommended_ids
            if tool_id in id_to_tool_meta
        ]
        print(
            f"[Tool Recommender] Final recommended tools: {[t['name'] for t in sorted_tools]}"
        )
        return sorted_tools