yiqing111 commited on
Commit
36a7484
·
verified ·
1 Parent(s): 60cb1db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +435 -0
app.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import hashlib
5
+ import sqlite3
6
+ import logging
7
+ from datetime import datetime
8
+ from typing import List, Dict, Any
9
+
10
+ import streamlit as st
11
+ import boto3
12
+ from botocore.exceptions import ClientError
13
+ from botocore.config import Config
14
+
15
+ # ========== 配置和常量 ==========
16
+ DB_FILE = 'audit_logs.db'
17
+ MODEL_ID = os.getenv("BEDROCK_MODEL_ID", "anthropic.claude-3-sonnet-20240229-v1:0")
18
+ DEFAULT_REGION = os.getenv("AWS_REGION", "us-east-1")
19
+
20
+ # ========== 日志配置 ==========
21
+ logger = logging.getLogger("audit_logger")
22
+ logger.setLevel(logging.INFO)
23
+ if not logger.handlers:
24
+ console = logging.StreamHandler()
25
+ console.setLevel(logging.INFO)
26
+ fmt = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S")
28
+ console.setFormatter(fmt)
29
+ logger.addHandler(console)
30
+
31
+ # ========== 数据库操作 ==========
32
+ def init_db() -> None:
33
+ """初始化或迁移 audit_logs 表"""
34
+ conn = sqlite3.connect(DB_FILE)
35
+ cur = conn.cursor()
36
+
37
+ # 创建基础表
38
+ cur.execute('''
39
+ CREATE TABLE IF NOT EXISTS audit_logs (
40
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
41
+ timestamp TEXT NOT NULL,
42
+ user_id TEXT NOT NULL,
43
+ prompt TEXT NOT NULL,
44
+ response TEXT NOT NULL
45
+ )
46
+ ''')
47
+
48
+ # 检查现有列
49
+ cur.execute("PRAGMA table_info(audit_logs)")
50
+ cols = {row[1] for row in cur.fetchall()}
51
+
52
+ # 添加缺失的列
53
+ if 'latency_ms' not in cols:
54
+ cur.execute("ALTER TABLE audit_logs ADD COLUMN latency_ms REAL NOT NULL DEFAULT 0")
55
+ if 'success' not in cols:
56
+ cur.execute("ALTER TABLE audit_logs ADD COLUMN success INTEGER NOT NULL DEFAULT 1")
57
+
58
+ conn.commit()
59
+ conn.close()
60
+
61
+ def record_log(
62
+ user_id: str,
63
+ prompt: str,
64
+ response: str,
65
+ latency_ms: float,
66
+ success: bool
67
+ ) -> None:
68
+ """存储审计条目并打印到控制台"""
69
+ timestamp = datetime.utcnow().isoformat()
70
+ success_flag = 1 if success else 0
71
+
72
+ # 1) 持久化到 SQLite
73
+ with sqlite3.connect(DB_FILE) as conn:
74
+ conn.execute('''
75
+ INSERT INTO audit_logs (
76
+ timestamp, user_id, prompt, response, latency_ms, success
77
+ ) VALUES (?, ?, ?, ?, ?, ?)
78
+ ''', (timestamp, user_id, prompt, response, latency_ms, success_flag))
79
+ conn.commit()
80
+
81
+ # 2) 打印到控制台
82
+ logger.info(
83
+ f"[Audit] user={user_id!r} prompt={prompt!r} "
84
+ f"latency={latency_ms:.1f}ms success={bool(success_flag)}"
85
+ )
86
+
87
+ def get_logs(limit: int = 100) -> List[Dict[str, Any]]:
88
+ """获取最近的审计日志条目"""
89
+ try:
90
+ with sqlite3.connect(DB_FILE) as conn:
91
+ cursor = conn.cursor()
92
+ cursor.execute('''
93
+ SELECT user_id, prompt, response, latency_ms, success, timestamp
94
+ FROM audit_logs
95
+ ORDER BY id DESC
96
+ LIMIT ?
97
+ ''', (limit,))
98
+ rows = cursor.fetchall()
99
+ except sqlite3.Error as e:
100
+ logger.error(f"DB error: {e}")
101
+ return []
102
+
103
+ return [
104
+ {
105
+ "user_id": row[0],
106
+ "prompt": row[1],
107
+ "response": row[2],
108
+ "latency_ms": row[3],
109
+ "success": bool(row[4]),
110
+ "timestamp": row[5]
111
+ }
112
+ for row in rows
113
+ ]
114
+
115
+ # ========== AWS 认证 ==========
116
+ def authenticate_aws(access_key: str, secret_key: str, region: str = 'us-east-1') -> bool:
117
+ """验证 AWS 凭证"""
118
+ try:
119
+ client = boto3.client(
120
+ 'bedrock-runtime',
121
+ aws_access_key_id=access_key,
122
+ aws_secret_access_key=secret_key,
123
+ region_name=region
124
+ )
125
+ return True
126
+ except ClientError:
127
+ return False
128
+
129
+ # ========== Bedrock 客户端 ==========
130
+ def invoke_chat_model(
131
+ access_key: str,
132
+ secret_key: str,
133
+ region: str,
134
+ model_id: str,
135
+ messages: list,
136
+ anthropic_version: str = "2023-05-31",
137
+ max_tokens_to_sample: int = 512,
138
+ temperature: float = 0.7
139
+ ) -> str:
140
+ """调用 Anthropic 聊天模型"""
141
+ # 分离系统提示并准备消息列表
142
+ system_prompt = None
143
+ chat_messages = []
144
+ for m in messages:
145
+ role = m.get("role")
146
+ content = m.get("content")
147
+ if role == "system":
148
+ system_prompt = content
149
+ else:
150
+ chat_messages.append({"role": role, "content": content})
151
+
152
+ # 确保第一条消息来自用户
153
+ while chat_messages and chat_messages[0].get("role") != "user":
154
+ chat_messages.pop(0)
155
+
156
+ if not chat_messages or chat_messages[0].get("role") != "user":
157
+ raise RuntimeError("No user message found in the conversation.")
158
+
159
+ # 构建 Bedrock 请求体
160
+ ver = anthropic_version
161
+ if not ver.startswith("bedrock-"):
162
+ ver = f"bedrock-{ver}"
163
+
164
+ body = {
165
+ "anthropic_version": ver,
166
+ "max_tokens": max_tokens_to_sample,
167
+ "temperature": temperature,
168
+ "messages": chat_messages,
169
+ }
170
+ if system_prompt is not None:
171
+ body["system"] = system_prompt
172
+
173
+ body_bytes = json.dumps(body).encode("utf-8")
174
+
175
+ # 创建 Bedrock Runtime 客户端
176
+ try:
177
+ client = boto3.client(
178
+ "bedrock-runtime",
179
+ aws_access_key_id=access_key,
180
+ aws_secret_access_key=secret_key,
181
+ region_name=region,
182
+ config=Config(
183
+ retries={"max_attempts": 1},
184
+ read_timeout=60,
185
+ connect_timeout=5
186
+ )
187
+ )
188
+ except ClientError as e:
189
+ raise RuntimeError(f"Anthropic invocation failed: {e}")
190
+
191
+ try:
192
+ # 调用模型
193
+ resp = client.invoke_model(
194
+ modelId=model_id,
195
+ contentType="application/json",
196
+ accept="application/json",
197
+ body=body_bytes
198
+ )
199
+
200
+ # 读取和解析响应
201
+ raw = resp["body"].read().decode("utf-8")
202
+ data = json.loads(raw)
203
+
204
+ # Legacy completions API
205
+ if "completions" in data:
206
+ return data["completions"][0]["completion"]
207
+ if "completion" in data:
208
+ return data["completion"]
209
+
210
+ # Messages API → top-level "content" blocks
211
+ if "content" in data and isinstance(data["content"], list):
212
+ return "".join(
213
+ block.get("text", "")
214
+ for block in data["content"]
215
+ if block.get("type") == "text"
216
+ )
217
+
218
+ # Alternative: top-level "messages" list
219
+ if "messages" in data and isinstance(data["messages"], list):
220
+ for msg in data["messages"]:
221
+ if msg.get("role") == "assistant":
222
+ c = msg.get("content")
223
+ if isinstance(c, list):
224
+ return "".join(
225
+ block.get("text", "")
226
+ for block in c
227
+ if block.get("type") == "text"
228
+ )
229
+ elif isinstance(c, str):
230
+ return c
231
+ first = data["messages"][0].get("content")
232
+ return first if isinstance(first, str) else ""
233
+
234
+ raise RuntimeError(f"Unrecognized Anthropic response shape: {data}")
235
+
236
+ except ClientError as e:
237
+ raise RuntimeError(f"Anthropic invocation failed: {e}")
238
+
239
+ # ========== 聊天功能 ==========
240
+ def chat_with_bedrock(
241
+ user_id: str,
242
+ access_key: str,
243
+ secret_key: str,
244
+ region: str,
245
+ messages: list
246
+ ) -> Dict[str, Any]:
247
+ """与 Bedrock 聊天并记录日志"""
248
+ # 验证凭证
249
+ if not authenticate_aws(access_key, secret_key, region):
250
+ return {"success": False, "response": "Invalid AWS credentials"}
251
+
252
+ # 调用模型并测量性能
253
+ start = time.perf_counter()
254
+ try:
255
+ reply = invoke_chat_model(
256
+ access_key,
257
+ secret_key,
258
+ region,
259
+ MODEL_ID,
260
+ messages
261
+ )
262
+ success = True
263
+ except Exception as e:
264
+ reply = str(e)
265
+ success = False
266
+ latency_ms = (time.perf_counter() - start) * 1000
267
+
268
+ # 审计日志
269
+ last_prompt = messages[-1].get("content", "") if messages else ""
270
+ record_log(
271
+ user_id=user_id,
272
+ prompt=last_prompt,
273
+ response=reply,
274
+ latency_ms=latency_ms,
275
+ success=success
276
+ )
277
+
278
+ return {"success": success, "response": reply}
279
+
280
+ # ========== 前端功能 ==========
281
+ def get_user_id(access_key: str) -> str:
282
+ """通过哈希 AWS access key 计算稳定的用户标识符"""
283
+ return hashlib.sha256(access_key.encode("utf-8")).hexdigest()
284
+
285
+ def login_sidebar() -> None:
286
+ """渲染侧边栏登录表单"""
287
+ st.sidebar.header("🔐 AWS 认证")
288
+ access_key = st.sidebar.text_input("AWS Access Key ID", type="password")
289
+ secret_key = st.sidebar.text_input("AWS Secret Access Key", type="password")
290
+ region = st.sidebar.text_input(
291
+ "AWS Region", value=DEFAULT_REGION
292
+ )
293
+
294
+ if st.sidebar.button("登录"):
295
+ if not access_key or not secret_key:
296
+ st.sidebar.error("请输入 AWS 凭证")
297
+ return
298
+
299
+ # 准备测试请求以验证凭证
300
+ payload = {
301
+ "user_id": get_user_id(access_key),
302
+ "access_key": access_key,
303
+ "secret_key": secret_key,
304
+ "region": region,
305
+ "messages": [
306
+ {"role": "system", "content": "Authenticate"},
307
+ {"role": "user", "content": "Ping"}
308
+ ]
309
+ }
310
+
311
+ result = chat_with_bedrock(**payload)
312
+ if result["success"]:
313
+ # 存储凭证供聊天使用
314
+ st.session_state.authenticated = True
315
+ st.session_state.access_key = access_key
316
+ st.session_state.secret_key = secret_key
317
+ st.session_state.region = region
318
+ st.sidebar.success("登录成功!")
319
+ st.rerun()
320
+ else:
321
+ st.sidebar.error("认证失败:请检查您的密钥")
322
+
323
+ def send_message_and_get_reply(user_text: str) -> str:
324
+ """发送消息并获取回复"""
325
+ # 添加新的用户消息
326
+ st.session_state.messages.append({"role": "user", "content": user_text})
327
+
328
+ # 构建包含所有消息的负载
329
+ payload = {
330
+ "user_id": get_user_id(st.session_state.access_key),
331
+ "access_key": st.session_state.access_key,
332
+ "secret_key": st.session_state.secret_key,
333
+ "region": st.session_state.region,
334
+ "messages": st.session_state.messages,
335
+ }
336
+
337
+ # 调用后端
338
+ result = chat_with_bedrock(**payload)
339
+ if not result["success"]:
340
+ raise Exception(result["response"])
341
+
342
+ return result["response"]
343
+
344
+ def render_chat_interface() -> None:
345
+ """显示聊天消息并处理新的用户输入"""
346
+ # 如果是首次加载,初始化消息历史
347
+ if "messages" not in st.session_state:
348
+ st.session_state.messages = [{"role": "assistant", "content": "欢迎!我是基于 Amazon Bedrock 的聊天助手。请告诉我您需要什么帮助。"}]
349
+
350
+ # 渲染所有现有消息
351
+ for msg in st.session_state.messages:
352
+ st.chat_message(msg["role"]).write(msg["content"])
353
+
354
+ # 处理用户输入
355
+ if user_text := st.chat_input("您的消息..."):
356
+ # 立即渲染用户气泡
357
+ st.chat_message("user").write(user_text)
358
+
359
+ # 发送到后端并获取回复
360
+ try:
361
+ reply = send_message_and_get_reply(user_text)
362
+ except Exception as e:
363
+ st.error(f"聊天服务错误: {e}")
364
+ return
365
+
366
+ # 渲染助手响应
367
+ st.session_state.messages.append({"role": "assistant", "content": reply})
368
+ st.chat_message("assistant").write(reply)
369
+
370
+ def render_dashboard() -> None:
371
+ """渲染仪表板"""
372
+ st.subheader("📊 聊天记录和统计")
373
+
374
+ logs = get_logs(limit=50)
375
+ if not logs:
376
+ st.info("暂无聊天记录")
377
+ return
378
+
379
+ # 转换为 DataFrame 用于显示
380
+ import pandas as pd
381
+ df = pd.DataFrame(logs)
382
+
383
+ # 显示统计信息
384
+ col1, col2, col3, col4 = st.columns(4)
385
+ with col1:
386
+ st.metric("总对话数", len(df))
387
+ with col2:
388
+ st.metric("成功率", f"{df['success'].mean()*100:.1f}%")
389
+ with col3:
390
+ st.metric("平均延迟", f"{df['latency_ms'].mean():.1f}ms")
391
+ with col4:
392
+ st.metric("唯一用户", df['user_id'].nunique())
393
+
394
+ # 显示最近的聊天记录
395
+ st.subheader("最近的聊天记录")
396
+ for log in logs[:10]: # 显示最近10条
397
+ with st.expander(f"用户 {log['user_id'][:8]}... - {log['timestamp']}"):
398
+ st.write(f"**用户**: {log['prompt']}")
399
+ st.write(f"**助手**: {log['response']}")
400
+ st.write(f"**延迟**: {log['latency_ms']:.1f}ms")
401
+ st.write(f"**状态**: {'✅ 成功' if log['success'] else '❌ 失败'}")
402
+
403
+ # ========== 主应用 ==========
404
+ def main() -> None:
405
+ """主入口点"""
406
+ # 初始化数据库
407
+ init_db()
408
+
409
+ # 页面配置和标题
410
+ st.set_page_config(
411
+ page_title="企业级 Bedrock 聊天机器人",
412
+ page_icon="💼",
413
+ layout="wide"
414
+ )
415
+
416
+ # 侧边栏登录
417
+ login_sidebar()
418
+
419
+ # 主标题
420
+ st.title("💼 企业级 Bedrock 聊天机器人")
421
+
422
+ # 如果已认证,显示功能选择
423
+ if st.session_state.get("authenticated"):
424
+ tab1, tab2 = st.tabs(["💬 聊天", "📊 仪表板"])
425
+
426
+ with tab1:
427
+ render_chat_interface()
428
+
429
+ with tab2:
430
+ render_dashboard()
431
+ else:
432
+ st.info("请在侧边栏使用 AWS 凭证登录以开始聊天。")
433
+
434
+ if __name__ == "__main__":
435
+ main()