aifeifei798 commited on
Commit
7cae5d6
·
verified ·
1 Parent(s): e0335b9

Upload setup.py

Browse files
Files changed (1) hide show
  1. database/setup.py +159 -135
database/setup.py CHANGED
@@ -1,150 +1,174 @@
1
- import os
2
- import sqlite3
 
3
  import json
4
- from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
5
- import google.generativeai as genai
 
 
 
 
 
 
 
6
 
7
- from tools.tool_registry import get_all_tools
 
8
 
9
- # --- 配置持久化路径 ---
10
- DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
11
- SQLITE_DB_PATH = os.path.join(DATA_DIR, "tools.metadata.db")
12
- MILVUS_DATA_PATH = os.path.join(DATA_DIR, "milvus_lite.db")
 
13
 
14
- # --- 模型配置 ---
15
- EMBEDDING_DIM = 3072
16
- EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
17
- MILVUS_COLLECTION_NAME = "tool_embeddings"
18
 
 
 
19
 
20
- def initialize_system():
21
- print("--- 开始系统初始化 (最终通关版) ---")
22
- os.makedirs(DATA_DIR, exist_ok=True)
23
 
24
- # --- 正确的初始化顺序 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # 1. 初始化SQLite并同步工具元数据
27
- # 确保SQLite里总是有最新的工具信息
28
- _init_sqlite_db()
29
- all_tools_definitions = get_all_tools()
30
- _sync_tools_to_sqlite(all_tools_definitions)
 
31
 
32
- # 2. 初始化Milvus并同步向量
33
- # 它会从已经填充好的SQLite中读取数据
34
- milvus_client = _init_milvus_and_sync_embeddings()
35
 
36
- # 3. 创建工具推荐器
37
- from core.tool_recommender import DirectToolRecommender
 
38
 
39
- tool_recommender = DirectToolRecommender(
40
- milvus_client=milvus_client, sqlite_db_path=SQLITE_DB_PATH
41
- )
 
42
 
43
- print("--- 系统初始化完成 ---")
44
- return all_tools_definitions, tool_recommender
45
 
 
 
 
 
 
46
 
47
- def _init_sqlite_db():
48
- print(f"SQLite DB 路径: {SQLITE_DB_PATH}")
49
- with sqlite3.connect(SQLITE_DB_PATH) as conn:
50
- cursor = conn.cursor()
51
- cursor.execute(
52
- """
53
- CREATE TABLE IF NOT EXISTS tools (
54
- id INTEGER PRIMARY KEY AUTOINCREMENT,
55
- name TEXT UNIQUE NOT NULL,
56
- description TEXT NOT NULL,
57
- parameters TEXT NOT NULL
58
- )
59
- """
60
- )
61
- conn.commit()
62
- print("SQLite DB 表已确认存在。")
63
-
64
-
65
- def _sync_tools_to_sqlite(tools_definitions):
66
- print("正在同步工具元数据到SQLite...")
67
- with sqlite3.connect(SQLITE_DB_PATH) as conn:
68
- cursor = conn.cursor()
69
- for tool in tools_definitions:
70
- cursor.execute("SELECT id FROM tools WHERE name = ?", (tool.name,))
71
- if cursor.fetchone() is None:
72
- cursor.execute(
73
- "INSERT INTO tools (name, description, parameters) VALUES (?, ?, ?)",
74
- (tool.name, tool.description, json.dumps(tool.args)),
75
  )
76
- print(f" - 新增工具到SQLite: {tool.name}")
77
- conn.commit()
78
- print("SQLite同步完成。")
79
-
80
-
81
- def _init_milvus_and_sync_embeddings():
82
- print(f"Milvus Lite 数据路径: {MILVUS_DATA_PATH}")
83
- client = MilvusClient(uri=MILVUS_DATA_PATH)
84
-
85
- # 每次启动都重新创建集合,确保维度正确且数据最新
86
- if client.has_collection(collection_name=MILVUS_COLLECTION_NAME):
87
- client.drop_collection(collection_name=MILVUS_COLLECTION_NAME)
88
- print("发现旧的Milvus集合,已删除以重建。")
89
-
90
- print(f"Milvus集合 '{MILVUS_COLLECTION_NAME}' 正在创建,维度为 {EMBEDDING_DIM}...")
91
- fields = [
92
- FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
93
- FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
94
- ]
95
- schema = CollectionSchema(fields)
96
- client.create_collection(collection_name=MILVUS_COLLECTION_NAME, schema=schema)
97
-
98
- index_params = client.prepare_index_params()
99
- index_params.add_index(
100
- field_name="embedding", index_type="AUTOINDEX", metric_type="L2"
101
- )
102
- client.create_index(
103
- collection_name=MILVUS_COLLECTION_NAME, index_params=index_params
104
- )
105
- print("Milvus集合和索引创建完成。")
106
-
107
- # --- 关键:现在我们才同步嵌入 ---
108
- _sync_tool_embeddings_to_milvus(client)
109
-
110
- client.load_collection(collection_name=MILVUS_COLLECTION_NAME)
111
- return client
112
-
113
-
114
- def _sync_tool_embeddings_to_milvus(milvus_client):
115
- print("正在同步工具嵌入到Milvus...")
116
- api_key = os.environ.get("GEMINI_API_KEY")
117
- if not api_key:
118
- print("错误:无法找到GEMINI_API_KEY。")
119
- return
120
- genai.configure(api_key=api_key)
121
-
122
- with sqlite3.connect(SQLITE_DB_PATH) as conn:
123
- cursor = conn.cursor()
124
- cursor.execute("SELECT id, description FROM tools")
125
- all_tools_in_db = cursor.fetchall()
126
-
127
- if not all_tools_in_db:
128
- print("SQLite中没有工具可同步,这是一个错误!")
129
- return
130
-
131
- print(f"从SQLite发现 {len(all_tools_in_db)} 个工具,正在生成嵌入...")
132
- docs_to_embed = [tool[1] for tool in all_tools_in_db]
133
-
134
- print(f"使用嵌入模型: {EMBEDDING_MODEL_NAME}")
135
- result = genai.embed_content(
136
- model=EMBEDDING_MODEL_NAME,
137
- content=docs_to_embed,
138
- task_type="retrieval_document",
139
- )
140
-
141
- embeddings = result["embedding"]
142
- tool_ids_to_insert = [tool[0] for tool in all_tools_in_db]
143
-
144
- data_to_insert = [
145
- {"id": tool_id, "embedding": embedding}
146
- for tool_id, embedding in zip(tool_ids_to_insert, embeddings)
147
- ]
148
-
149
- milvus_client.insert(collection_name=MILVUS_COLLECTION_NAME, data=data_to_insert)
150
- print(f"成功将 {len(data_to_insert)} 个新嵌入插入到Milvus。")
 
1
+ from langchain_google_genai import ChatGoogleGenerativeAI
2
+ from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
3
+ from typing import List, Any
4
  import json
5
+ import os
6
+ import re # 导入正则表达式库
7
+
8
+ from .tool_recommender import DirectToolRecommender
9
+ from tools.tool_registry import get_tool_by_name
10
+
11
+ # Agent的思考模板 (保持不变)
12
+ AGENT_PROMPT_TEMPLATE = """
13
+ 你是一个强大的AI助理。你的任务是理解用户的问题,并决定是否需要使用工具来回答。
14
 
15
+ 你有以下工具可用:
16
+ {tools}
17
 
18
+ 如果需要使用工具,请严格按照以下JSON格式进行响应,不要包含任何其他文本或解释:
19
+ {{
20
+ "tool": "要调用的工具名称",
21
+ "tool_input": {{ "参数1": "值1", "参数2": "值2" }}
22
+ }}
23
 
24
+ 如果不需要使用任何工具,请直接回答用户的问题。
 
 
 
25
 
26
+ 这是对话历史:
27
+ {chat_history}
28
 
29
+ 用户问题:{input}
 
 
30
 
31
+ 现在,请你思考并作出回应(JSON或直接回答):
32
+ """
33
+
34
+
35
+ class SmartAIAgent:
36
+ def __init__(
37
+ self,
38
+ tool_recommender: DirectToolRecommender,
39
+ registered_tools: List[Any],
40
+ api_key: str,
41
+ ):
42
+ self.tool_recommender = tool_recommender
43
+ self.registered_tools = registered_tools
44
+ self.model_name = "gemini-2.5-flash"
45
+ self.llm = ChatGoogleGenerativeAI(
46
+ model=self.model_name,
47
+ google_api_key=api_key,
48
+ convert_system_message_to_human=True,
49
+ )
50
+ self.chat_history = []
51
+ print(f"LangChain Agent已初始化,使用模型: {self.model_name}。")
52
+
53
+ # ------------------- 核心修复在这里! -------------------
54
+ # 我们添加一个更健壮的JSON提取函数
55
+ def _extract_json_from_string(self, text: str) -> dict | None:
56
+ """从可能包含其他文本的字符串中提取出JSON块。"""
57
+ # 匹配被 markdown 包裹的JSON
58
+ match = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL)
59
+ if match:
60
+ json_str = match.group(1)
61
+ else:
62
+ # 匹配裸露的JSON
63
+ match = re.search(r"\{.*\}", text, re.DOTALL)
64
+ if match:
65
+ json_str = match.group(0)
66
+ else:
67
+ return None
68
+
69
+ try:
70
+ return json.loads(json_str)
71
+ except json.JSONDecodeError:
72
+ return None
73
+
74
+ # ----------------------------------------------------
75
+
76
+ def _format_tools_for_prompt(self, tools: List[dict]) -> str:
77
+ # ... (此函数保持不变) ...
78
+ if not tools:
79
+ return "没有可用的工具。"
80
+ tool_strings = []
81
+ for tool in tools:
82
+ try:
83
+ params = json.loads(tool["parameters"])
84
+ param_str = ", ".join(
85
+ [f"{p_name}: {p_type}" for p_name, p_type in params.items()]
86
+ )
87
+ tool_strings.append(
88
+ f"- 工具名称: {tool['name']}\n - 描述: {tool['description']}\n - 参数: {param_str}"
89
+ )
90
+ except (json.JSONDecodeError, TypeError):
91
+ tool_strings.append(
92
+ f"- 工具名称: {tool['name']}\n - 描述: {tool['description']}\n - 参数: 无法解析"
93
+ )
94
+ return "\n".join(tool_strings)
95
+
96
+ def _format_chat_history(self) -> str:
97
+ # ... (此函数保持不变) ...
98
+ formatted_history = []
99
+ for msg in self.chat_history:
100
+ if isinstance(msg, HumanMessage):
101
+ formatted_history.append(f"用户: {msg.content}")
102
+ elif isinstance(msg, AIMessage):
103
+ formatted_history.append(f"助理: {msg.content}")
104
+ elif isinstance(msg, ToolMessage):
105
+ formatted_history.append(f"工具结果: {msg.content}")
106
+ return "\n".join(formatted_history)
107
+
108
+ def stream_run(self, user_input: str):
109
+ self.chat_history.append(HumanMessage(content=user_input))
110
+ yield "🤔 正在分析您的问题...\n"
111
+
112
+ yield "🔍 正在从工具库中推荐相关工具...\n"
113
+ recommended_tools_meta = self.tool_recommender.recommend_tools(user_input)
114
+
115
+ if not recommended_tools_meta:
116
+ yield "ℹ️ 未找到相关工具,将直接回答。\n"
117
+ recommended_tools_prompt = "没有推荐的工具。"
118
+ else:
119
+ tool_names = [t["name"] for t in recommended_tools_meta]
120
+ yield f"✅ 推荐工具: `{', '.join(tool_names)}`\n"
121
+ recommended_tools_prompt = self._format_tools_for_prompt(
122
+ recommended_tools_meta
123
+ )
124
 
125
+ yield f"🧠 正在让AI大脑({self.model_name})决定如何行动...\n"
126
+ prompt = AGENT_PROMPT_TEMPLATE.format(
127
+ tools=recommended_tools_prompt,
128
+ chat_history=self._format_chat_history(),
129
+ input=user_input,
130
+ )
131
 
132
+ llm_response = self.llm.invoke(prompt)
133
+ llm_decision_content = llm_response.content.strip()
 
134
 
135
+ # ------------------- 核心修复在这里! -------------------
136
+ # 使用我们新的、更健壮的JSON提取逻辑
137
+ decision = self._extract_json_from_string(llm_decision_content)
138
 
139
+ if decision and "tool" in decision and "tool_input" in decision:
140
+ # 如果成功提取出有效的工具调用JSON
141
+ tool_name = decision.get("tool")
142
+ tool_input = decision.get("tool_input")
143
 
144
+ yield f"💡 AI决策:调用工具 `{tool_name}`,参数为 `{tool_input}`\n"
 
145
 
146
+ tool_to_execute = get_tool_by_name(tool_name)
147
+ if tool_to_execute:
148
+ yield f"⚙️ 正在执行工具 `{tool_name}`...\n"
149
+ tool_output = tool_to_execute.invoke(tool_input)
150
+ yield f"📊 工具返回结果:\n---\n{str(tool_output)[:500]}...\n---\n"
151
 
152
+ self.chat_history.append(
153
+ AIMessage(content=json.dumps(decision, ensure_ascii=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  )
155
+ self.chat_history.append(
156
+ ToolMessage(content=str(tool_output), tool_call_id="N/A")
157
+ )
158
+
159
+ yield "✍️ 正在根据工具结果生成最终回答...\n\n"
160
+ final_answer_prompt = f"基于以下对话历史和最新的工具结果,请为用户生成一个最终的、完整的、自然的回答。\n\n对话历史:\n{self._format_chat_history()}\n\n请直接回答,不要提及你的思考过程。"
161
+ final_answer_stream = self.llm.stream(final_answer_prompt)
162
+ full_final_answer = ""
163
+ for chunk in final_answer_stream:
164
+ yield chunk.content
165
+ full_final_answer += chunk.content
166
+ self.chat_history.append(AIMessage(content=full_final_answer))
167
+ else:
168
+ yield f"❌ 错误:AI决策调用的工具 `{tool_name}` 不存在。\n"
169
+ else:
170
+ # 如果没有提取出JSON,或者JSON格式不正确,则认为是直接回答
171
+ yield " AI决策:直接回答。\n\n"
172
+ yield llm_decision_content
173
+ self.chat_history.append(AIMessage(content=llm_decision_content))
174
+ # ----------------------------------------------------