Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import time | |
| import gradio as gr | |
| import modelscope_studio.components.antd as antd | |
| import modelscope_studio.components.antdx as antdx | |
| import modelscope_studio.components.base as ms | |
| import modelscope_studio.components.pro as pro | |
| from mem0 import Memory | |
| from modelscope_studio.components.pro.chatbot import (ChatbotBotConfig, | |
| ChatbotPromptsConfig, | |
| ChatbotUserConfig, | |
| ChatbotWelcomeConfig) | |
| from openai import OpenAI | |
# mem0 memory store, configured to use the FAISS vector store under ./faiss_memories.
config = dict(
    vector_store=dict(
        provider="faiss",
        config=dict(
            collection_name="test",
            path="./faiss_memories",
            distance_strategy="euclidean",
        ),
    ),
)
m = Memory.from_config(config)

# OpenAI-compatible client for the GeniuWorks endpoint; the API key is read
# from the GW_API_KEY environment variable.
gw_api_key = os.getenv("GW_API_KEY")
client = OpenAI(
    api_key=gw_api_key,
    base_url='https://api.geniuworks.com/v2',
)

# Chat model used for completions (alternative tried before: "gpt-4.1-2025-04-14").
model = "xinyuan-32b-v0609"

# --- User management ---
# Plain-text registry of usernames, one per line.
USERS_FILE = "users.txt"
def load_users():
    """Return the set of registered usernames.

    Each non-blank line of ``USERS_FILE`` (surrounding whitespace stripped) is
    one username. A missing file means no users yet, so an empty set is returned.
    """
    if os.path.exists(USERS_FILE):
        with open(USERS_FILE, 'r', encoding='utf-8') as handle:
            stripped = (raw.strip() for raw in handle)
            return {name for name in stripped if name}
    return set()
def save_user(username):
    """Append *username* to ``USERS_FILE`` as a new line (file is created if absent)."""
    with open(USERS_FILE, mode='a', encoding='utf-8') as users_file:
        users_file.writelines([username, '\n'])
_USERNAME_PATTERN = re.compile(r'[a-zA-Z][a-zA-Z0-9_]*')
_MIN_USERNAME_LENGTH = 3


def is_valid_username(username):
    """Return True if *username* is an acceptable account name.

    Rules: at least 3 characters, starts with an ASCII letter, and contains
    only ASCII letters, digits and underscores. ``None``, empty strings and
    non-string values are rejected.
    """
    if not isinstance(username, str) or len(username) < _MIN_USERNAME_LENGTH:
        return False
    # Use fullmatch rather than match(...$): "$" also matches just before a
    # trailing newline, so "abc\n" would be accepted even though it can never
    # equal the stripped names read back from the users file.
    return _USERNAME_PATTERN.fullmatch(username) is not None
def login_user(username):
    """Authenticate an existing user by name.

    Returns:
        (ok, message): ``ok`` is True only when *username* passes
        ``is_valid_username`` and is present in ``load_users()``; ``message``
        is the text to display in the UI.
    """
    if is_valid_username(username):
        if username in load_users():
            return True, f"欢迎回来,{username}!"
        return False, f"用户 {username} 未注册,请先注册。"
    return False, "用户名无效!用户名必须以英文字母开头,只能包含英文字母、数字和下划线,且长度至少3位。"
def register_user(username):
    """Create a new account for *username*.

    Returns:
        (ok, message): ``ok`` is True when the name was valid, not already
        registered, and has been appended via ``save_user``; ``message`` is
        the text to display in the UI.
    """
    if not is_valid_username(username):
        outcome = (False, "用户名无效!用户名必须以英文字母开头,只能包含英文字母、数字和下划线,且长度至少3位。")
    elif username in load_users():
        outcome = (False, f"用户名 {username} 已存在,请直接登录。")
    else:
        save_user(username)
        outcome = (True, f"注册成功!欢迎,{username}!")
    return outcome
def handle_auth(username, is_register):
    """Run login or registration and build the resulting UI updates.

    Returns a 4-tuple for (login panel, chat panel, status alert,
    current-user value). On success the login panel is hidden, the chat panel
    is shown and the username is returned. On failure the login panel stays
    visible, the chat panel stays hidden and "" is returned.
    """
    action = register_user if is_register else login_user
    success, message = action(username)
    alert_kind = "success" if success else "error"
    return (
        gr.update(visible=not success),
        gr.update(visible=success),
        gr.update(message=message, type=alert_kind, visible=True),
        username if success else "",
    )
def prompt_select(e: gr.EventData):
    """Copy the description of the clicked welcome prompt into the input box."""
    picked_prompt = e._data["payload"][0]["value"]
    return gr.update(value=picked_prompt["description"])
def clear():
    """Reset the chatbot's value to None, emptying the displayed conversation."""
    empty_conversation = None
    return gr.update(value=empty_conversation)
def retry(chatbot_value, e: gr.EventData, username=None):
    """Regenerate a reply selected through the chatbot's "retry" action.

    The history is cut just before the selected message, the input controls
    are locked, and ``submit`` is re-run on the truncated history with no new
    user text. All of its updates are forwarded unchanged.
    """
    cut_at = e._data["payload"][0]["index"]
    kept_history = chatbot_value[:cut_at]
    lock_updates = (
        gr.update(loading=True),
        gr.update(value=kept_history),
        gr.update(disabled=True),
    )
    yield lock_updates
    for update in submit(None, kept_history, username):
        yield update
def cancel(chatbot_value):
    """Stop-button handler: finalize the last chatbot message and unlock input.

    The last message is marked not-loading and done, with a "paused" footer.
    Returns updates for (chatbot, sender, clear button).
    """
    chatbot_value[-1].update(loading=False,
                             status="done",
                             footer="Chat completion paused")
    chatbot_update = gr.update(value=chatbot_value)
    sender_update = gr.update(loading=False)
    clear_update = gr.update(disabled=False)
    return chatbot_update, sender_update, clear_update
def _last_user_text(history):
    """Return the content of the most recent user message in *history*, or None."""
    for item in reversed(history):
        if item.get("role") == "user":
            return item.get("content")
    return None


def _assistant_text(content):
    """Return the answer text of an assistant chatbot message, or "".

    Streamed replies store a list of segments whose last entry is the answer.
    Earlier entries form the "thinking" panel, which is not sent back to the
    model. Failed replies store a plain string instead.
    """
    if isinstance(content, str):
        return content
    if not content:
        return ""
    last_segment = content[-1]
    if isinstance(last_segment, dict):
        return last_segment.get("content") or ""
    return str(last_segment)


def _format_memories(search_result):
    """Render memory search hits as a system-prompt section, highest score first.

    Returns "" when there are no hits.
    """
    # Assumes mem0's search() returns {"results": [{"memory": str, "score": float, ...}]};
    # a bare list of hits is also accepted. Verify against the installed mem0 version.
    if isinstance(search_result, dict):
        hits = search_result.get("results") or []
    else:
        hits = search_result or []
    ranked = sorted(hits, key=lambda hit: hit.get("score") or 0, reverse=True)
    section = "".join(
        f"相关记忆{idx}:\n内容:{hit['memory']}\n相关度:{hit.get('score')}\n\n"
        for idx, hit in enumerate(ranked)
    )
    return f"\n相关记忆:\n{section}" if section else ""


def format_history(sender_value, history, username=None):
    """Build the OpenAI-style message list for a chat completion request.

    When *username* is set, exactly one system message is placed first. It
    names the user and lists any memories that ``m.search`` returns for the
    current query.

    Args:
        sender_value: Text the user just submitted, or None on retry. When it
            is None, the latest user message in *history* is used as the
            memory search query.
        history: Chatbot messages (dicts with "role" and "content"). User
            content is passed through as-is. For assistant messages only the
            answer text is kept; the thinking segment is dropped, and replies
            with no answer text (e.g. cancelled early) are skipped.
        username: Logged-in user name; scopes the memory search.

    Returns:
        list[dict]: Messages with "role" and "content" keys.
    """
    messages = []
    if username:
        system_prompt = f"""You are Xinyuan, a large language model trained by Cylingo Group. You are a helpful assistant. 目前和你聊天的用户是{username}."""
        query = sender_value if sender_value is not None else _last_user_text(history)
        if isinstance(query, str) and query:
            try:
                related_memories = m.search(query=query, user_id=username)
            except Exception as exc:  # memory is auxiliary: answer without it
                print(f"Memory search failed for user {username!r}: {exc!r}")
            else:
                system_prompt += _format_memories(related_memories)
        # A single system message, ahead of the conversation.
        messages.append({"role": "system", "content": system_prompt})
    for item in history:
        role = item.get("role")
        if role == "user":
            messages.append({"role": "user", "content": item["content"]})
        elif role == "assistant":
            answer = _assistant_text(item.get("content"))
            if answer:
                messages.append({"role": "assistant", "content": answer})
    return messages
def _stream_delta(chunk):
    """Return ``(reasoning_text, answer_text)`` from one streamed completion chunk.

    Missing pieces come back as "" instead of raising:
      - an empty or absent ``choices`` list;
      - a delta without a ``reasoning_content`` attribute (it may be absent);
      - ``None`` values.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return "", ""
    delta = getattr(choices[0], "delta", None)
    return (getattr(delta, "reasoning_content", None) or "",
            getattr(delta, "content", None) or "")


def _close_thought_panel(thought_segment, start_time):
    """Retitle the "thinking" segment with the elapsed time and mark it done."""
    elapsed = "{:.2f}".format(time.time() - start_time)
    thought_segment["options"]["title"] = f"End of Thought ({elapsed}s)"
    thought_segment["options"]["status"] = "done"


def _remember_exchange(username, user_text, answer_text):
    """Best-effort: store one user/assistant exchange in mem0 for *username*.

    Skipped when any argument is empty. Failures are printed and swallowed, so
    a memory-store problem never replaces an answer that was already shown.
    """
    if not (username and user_text and answer_text):
        return
    try:
        m.add(
            [
                {'role': 'user', 'content': user_text},
                {'role': 'assistant', 'content': answer_text},
            ],
            user_id=username,
        )
    except Exception as exc:  # deliberate best-effort boundary
        print(f"Failed to save memory for user {username!r}: {exc!r}")


def submit(sender_value, chatbot_value, username=None):
    """Stream an assistant reply into the chatbot, then save the exchange to memory.

    Generator used as a Gradio event handler. Each ``yield`` is a dict of
    updates for the ``sender``, ``clear_btn`` and ``chatbot`` components.

    Args:
        sender_value: Text the user just submitted. None means "regenerate"
            (used by ``retry``): no user message is appended and nothing is
            saved to memory.
        chatbot_value: Chatbot message list; mutated in place.
        username: Logged-in user; enables memory lookup and saving.

    Raises:
        Exception: Errors from the completion request are re-raised after the
            chatbot has been updated with a failure message.
    """
    if sender_value is not None:
        chatbot_value.append({
            "role": "user",
            "content": sender_value,
        })
    history_messages = format_history(sender_value, chatbot_value, username)
    chatbot_value.append({
        "role": "assistant",
        "content": [],
        "loading": True,
        "status": "pending"
    })
    yield {
        sender: gr.update(value=None, loading=True),
        clear_btn: gr.update(disabled=True),
        chatbot: gr.update(value=chatbot_value)
    }
    answer_parts = []  # answer text only (no reasoning), for the memory store
    try:
        response = client.chat.completions.create(model=model,
                                                  messages=history_messages,
                                                  stream=True,
                                                  max_tokens=32768,
                                                  temperature=0.6,
                                                  top_p=0.95,
                                                  )
        thought_done = False
        start_time = time.time()
        message_content = chatbot_value[-1]["content"]
        # Segment 0: collapsible "thinking" panel fed with reasoning_content.
        message_content.append({
            "copyable": False,
            "editable": False,
            "type": "tool",
            "content": "",
            "options": {
                "title": "Thinking..."
            }
        })
        # Segment 1: the visible answer text.
        message_content.append({
            "type": "text",
            "content": "",
        })
        thought_segment, answer_segment = message_content[-2], message_content[-1]
        for chunk in response:
            reasoning_content, content = _stream_delta(chunk)
            chatbot_value[-1]["loading"] = False
            thought_segment["content"] += reasoning_content
            answer_segment["content"] += content
            if content:
                answer_parts.append(content)
                # First answer token ends the thinking phase.
                if not thought_done:
                    thought_done = True
                    _close_thought_panel(thought_segment, start_time)
            yield {chatbot: gr.update(value=chatbot_value)}
        # No answer text arrived: don't leave the panel stuck on "Thinking...".
        if not thought_done:
            _close_thought_panel(thought_segment, start_time)
        chatbot_value[-1]["loading"] = False
        chatbot_value[-1]["footer"] = "{:.2f}".format(time.time() -
                                                      start_time) + 's'
        chatbot_value[-1]["status"] = "done"
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
    except Exception:
        chatbot_value[-1]["loading"] = False
        chatbot_value[-1]["status"] = "done"
        chatbot_value[-1]["content"] = "Failed to respond, please try again."
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
        raise
    # Persist only after the UI has been released. The memory call is not
    # needed for display, and its failure must not discard the shown answer.
    _remember_exchange(username, sender_value, "".join(answer_parts))
with gr.Blocks() as demo, ms.Application(), antdx.XProvider():
    # Username of the logged-in user for this session ("" = logged out).
    current_user = gr.State("")

    # Login / registration panel, shown until authentication succeeds.
    with antd.Flex(vertical=True, gap="large",
                   elem_id="login_container") as login_container:
        with antd.Card(title="欢迎使用 Xinyuan 聊天助手"):
            with antd.Flex(vertical=True, gap="middle"):
                antd.Typography.Title("用户登录/注册", level=3)
                antd.Typography.Text(
                    "请输入您的英文用户名(3位以上,仅支持英文字母、数字和下划线)")
                username_input = antd.Input(
                    placeholder="请输入用户名(如:john_doe)",
                    size="large"
                )
                with antd.Flex(gap="small"):
                    login_btn = antd.Button("登录", type="primary", size="large")
                    register_btn = antd.Button("注册", size="large")
                # Status banner for auth results; hidden until first used.
                auth_message = antd.Alert(
                    message="请输入用户名",
                    type="info",
                    visible=False
                )

    # Chat panel, hidden until login.
    with antd.Flex(vertical=True, gap="middle", visible=False) as chat_container:
        # Header row: current user on the left, logout button on the right.
        with antd.Flex(justify="space-between", align="center"):
            user_info = gr.Markdown("")
            logout_btn = antd.Button("退出登录", size="small")
        chatbot = pro.Chatbot(
            height=1000,
            welcome_config=ChatbotWelcomeConfig(
                variant="borderless",
                icon="./xinyuan.png",
                title="Hello, I'm Xinyuan👋",
                description="You can input text to get started.",
                prompts=ChatbotPromptsConfig(
                    title="How can I help you today?",
                    styles={
                        "list": {
                            "width": '100%',
                        },
                        "item": {
                            "flex": 1,
                        },
                    },
                    items=[{
                        "label":
                        "💝 心理学与实际应用",
                        "children": [{
                            "description":
                            "课题分离是什么意思?"
                        }, {
                            "description":
                            "回避型依恋和焦虑型依恋有什么区别?还有其他依恋类型吗?"
                        }, {
                            "description":
                            "为什么我背单词的时候总是只记得开头和结尾,中间全忘了?"
                        }]
                    }, {
                        "label":
                        "👪 儿童教育与发展",
                        "children": [{
                            "description":
                            "什么是正念养育?"
                        }, {
                            "description":
                            "2岁孩子分离焦虑严重,送托育中心天天哭闹怎么办?"
                        }, {
                            "description":
                            "4岁娃说话不清还爱打人,是心理问题还是欠管教?"
                        }]
                    }])),
            user_config=ChatbotUserConfig(
                avatar="https://api.dicebear.com/7.x/miniavs/svg?seed=3",
                variant="shadow"),
            bot_config=ChatbotBotConfig(
                header='Xinyuan',
                avatar="./xinyuan.png",
                actions=["copy", "retry"],
                variant="shadow"),
        )
        with antdx.Sender() as sender:
            with ms.Slot("prefix"):
                with antd.Button(value=None, color="default",
                                 variant="text") as clear_btn:
                    with ms.Slot("icon"):
                        antd.Icon("ClearOutlined")

    # ---- Event handlers ----
    def handle_login(username):
        """Log in an existing user (see handle_auth for the returned updates)."""
        return handle_auth(username, False)

    def handle_register(username):
        """Register a new user (see handle_auth for the returned updates)."""
        return handle_auth(username, True)

    def handle_logout():
        """Return to the login panel and drop all per-user state.

        Also empties the chatbot, so the next user to log in on this page
        does not see the previous user's conversation.
        """
        return (
            gr.update(visible=True),   # show login panel
            gr.update(visible=False),  # hide chat panel
            gr.update(message="已退出登录", type="info", visible=True),
            gr.update(value=""),       # clear username input
            "",                        # clear user info banner
            "",                        # clear current_user state
            gr.update(value=None),     # clear the conversation
        )

    def update_user_info(username):
        """Render the "current user" banner; empty when nobody is logged in."""
        if username:
            return f"**当前用户: {username}**"
        return ""

    # Login: authenticate, then refresh the user banner from current_user.
    login_btn.click(
        fn=handle_login,
        inputs=[username_input],
        outputs=[login_container, chat_container, auth_message, current_user]
    ).then(
        fn=update_user_info,
        inputs=[current_user],
        outputs=[user_info]
    )
    # Register: same flow as login.
    register_btn.click(
        fn=handle_register,
        inputs=[username_input],
        outputs=[login_container, chat_container, auth_message, current_user]
    ).then(
        fn=update_user_info,
        inputs=[current_user],
        outputs=[user_info]
    )
    # Logout: reset panels, inputs, state and the conversation.
    logout_btn.click(
        fn=handle_logout,
        outputs=[login_container, chat_container, auth_message, username_input,
                 user_info, current_user, chatbot]
    )

    # Chat events.
    clear_btn.click(fn=clear, outputs=[chatbot])
    submit_event = sender.submit(fn=submit,
                                 inputs=[sender, chatbot, current_user],
                                 outputs=[sender, chatbot, clear_btn])
    sender.cancel(fn=cancel,
                  inputs=[chatbot],
                  outputs=[chatbot, sender, clear_btn],
                  cancels=[submit_event],
                  queue=False)
    chatbot.retry(fn=retry,
                  inputs=[chatbot, current_user],
                  outputs=[sender, chatbot, clear_btn])
    chatbot.welcome_prompt_select(fn=prompt_select, outputs=[sender])
if __name__ == "__main__":
    # Enable request queuing on the app, then start the web server.
    app = demo.queue()
    app.launch()