Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import io | |
| import time | |
| import streamlit as st | |
| from PIL import Image | |
| from service import Service | |
| """ | |
| 使用 mistralai 官方库的 Service 类处理 API 请求 | |
| """ | |
| # 设置页面配置 - 必须是第一个Streamlit命令 | |
| st.set_page_config( | |
| page_title="Mistral 聊天助手", | |
| page_icon="🤖", | |
| layout="wide", | |
| initial_sidebar_state="collapsed" | |
| ) | |
| # 初始化API服务 | |
| service = Service() | |
| # 初始化会话状态 | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| if "image_data" not in st.session_state: | |
| st.session_state.image_data = None | |
| def encode_image_to_base64(image): | |
| """将图像转换为 base64 字符串""" | |
| if image is None: | |
| return None | |
| try: | |
| # 如果是PIL图像 | |
| if isinstance(image, Image.Image): | |
| buffered = io.BytesIO() | |
| image.save(buffered, format="PNG") | |
| img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
| return f"data:image/png;base64,{img_str}" | |
| # 如果是字节流或文件上传对象 | |
| elif hasattr(image, 'read') or isinstance(image, bytes): | |
| if hasattr(image, 'read'): | |
| image_bytes = image.read() | |
| else: | |
| image_bytes = image | |
| img_str = base64.b64encode(image_bytes).decode("utf-8") | |
| return f"data:image/png;base64,{img_str}" | |
| # 如果是文件路径 | |
| elif isinstance(image, str) and os.path.isfile(image): | |
| with open(image, "rb") as img_file: | |
| img_str = base64.b64encode(img_file.read()).decode("utf-8") | |
| return f"data:image/png;base64,{img_str}" | |
| else: | |
| st.error(f"不支持的图像类型: {type(image)}") | |
| return None | |
| except Exception as e: | |
| st.error(f"编码图像时出错: {str(e)}") | |
| return None | |
| def read_file_content(file_path): | |
| """提取文件内容""" | |
| if file_path is None: | |
| return None | |
| try: | |
| print(f"尝试读取文件内容: {file_path}") | |
| file_ext = os.path.splitext(file_path)[1].lower() | |
| # 文本文件扩展名列表 | |
| text_exts = ['.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.csv', '.xml', '.yaml', '.yml', '.ini', '.conf'] | |
| if file_ext in text_exts: | |
| try: | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| content = f.read() | |
| print(f"成功读取文件内容,长度: {len(content)}") | |
| return content | |
| except UnicodeDecodeError: | |
| # 尝试使用其他编码 | |
| try: | |
| with open(file_path, 'r', encoding='gbk') as f: | |
| content = f.read() | |
| print(f"使用GBK编码成功读取文件内容,长度: {len(content)}") | |
| return content | |
| except: | |
| print(f"无法解码文件内容,可能是二进制文件") | |
| return f"无法读取文件内容,文件可能是二进制格式或使用了不支持的编码。" | |
| else: | |
| return f"文件类型 {file_ext} 暂不支持直接读取内容,但我可以尝试分析文件名称。" | |
| except Exception as e: | |
| print(f"读取文件时出错: {str(e)}") | |
| return f"读取文件时出错: {str(e)}" | |
| def respond( | |
| message, | |
| history, | |
| system_message, | |
| max_tokens, | |
| temperature, | |
| top_p, | |
| image=None | |
| ): | |
| try: | |
| print(f"响应函数收到:message={message[:50]}...(已截断), 图片={image is not None}") | |
| # 准备完整的消息历史 | |
| messages = [{"role": "system", "content": system_message}] | |
| # 添加历史消息 | |
| for msg in history: | |
| if msg["role"] == "user": | |
| messages.append({"role": "user", "content": msg["content"]}) | |
| elif msg["role"] == "assistant": | |
| messages.append({"role": "assistant", "content": msg["content"]}) | |
| # 设置模型和参数 | |
| service.model = "mistral-small-latest" # 可以根据需要修改为其他模型 | |
| # 处理带图像的请求 | |
| if image is not None: | |
| print("处理带图像的请求...") | |
| # 使用 chat_with_image 方法处理多模态请求 | |
| response = service.chat_with_image( | |
| text_prompt=message if message else "请分析这张图片", | |
| image_base64=image, | |
| history=messages | |
| ) | |
| print("图像请求已发送到API") | |
| else: | |
| print("处理纯文本请求...") | |
| # 纯文本请求,添加用户消息并获取响应 | |
| messages.append({"role": "user", "content": message}) | |
| response = service.get_response(messages) | |
| # 返回响应结果 | |
| print(f"API返回响应: {response[:50]}...(已截断)") | |
| return response | |
| except Exception as e: | |
| print(f"API 请求错误: {str(e)}") | |
| return f"处理请求时出错: {str(e)}" | |
| # 加载系统提示 | |
| def load_system_prompt(): | |
| return """你是一个有帮助的AI助手,可以回答用户的问题,也可以分析用户上传的图片。 | |
| 如果用户上传了图片,请详细描述图片内容,并回答用户关于图片的问题。 | |
| 如果用户没有上传图片,请正常回答用户的文本问题。 | |
| """ | |
| # 获取API响应 | |
| def get_api_response(prompt, image_data=None): | |
| try: | |
| # 准备消息历史(不包括最新的用户消息) | |
| messages = [] | |
| # 添加系统消息 | |
| messages.append({"role": "system", "content": load_system_prompt()}) | |
| # 添加历史消息 | |
| for msg in st.session_state.messages: | |
| if msg["role"] != "system": # 跳过系统消息,因为我们已经添加了 | |
| messages.append({"role": msg["role"], "content": msg["content"]}) | |
| # 处理带图像的请求 | |
| if image_data: | |
| st.info("正在处理图像...") | |
| # 使用 chat_with_image 方法处理多模态请求 | |
| return service.chat_with_image( | |
| text_prompt=prompt if prompt else "请分析这张图片", | |
| image_base64=image_data, | |
| history=messages | |
| ) | |
| else: | |
| # 添加最新的用户消息 | |
| messages.append({"role": "user", "content": prompt}) | |
| # 纯文本请求 | |
| return service.get_response(messages) | |
| except Exception as e: | |
| st.error(f"API 请求错误: {str(e)}") | |
| return f"处理请求时出错: {str(e)}" | |
| # 显示标题和说明 | |
| st.title("🤖 Mistral 多模态聊天助手") | |
| st.markdown(""" | |
| ### 使用说明 | |
| - 输入文字问题并按回车发送 | |
| - 点击"📋 粘贴图片"按钮,然后粘贴剪贴板中的图片 | |
| - 也可以使用"📎 上传图片"上传本地图片文件 | |
| - 图片和文字可以一起发送,或单独发送 | |
| """) | |
| # 创建两列布局 | |
| col1, col2 = st.columns([3, 1]) | |
| with col2: | |
| st.subheader("选项") | |
| # 添加图片上传按钮 | |
| uploaded_file = st.file_uploader("📎 上传图片", type=["jpg", "jpeg", "png"], key="file_uploader") | |
| # 粘贴图片按钮 | |
| if st.button("📋 粘贴图片"): | |
| st.session_state.paste_mode = True | |
| # 粘贴模式激活时显示粘贴区域 | |
| if "paste_mode" in st.session_state and st.session_state.paste_mode: | |
| st.markdown("### 粘贴图片区域") | |
| st.markdown("按 Ctrl+V 粘贴图片") | |
| # 使用实验性功能接收粘贴的图片 | |
| pasted_image = st.camera_input("粘贴的图片会显示在这里", key="camera") | |
| if pasted_image: | |
| st.session_state.image_data = encode_image_to_base64(pasted_image) | |
| st.session_state.paste_mode = False | |
| st.experimental_rerun() | |
| # 如果通过文件上传器上传了图片 | |
| if uploaded_file: | |
| st.session_state.image_data = encode_image_to_base64(uploaded_file) | |
| st.image(uploaded_file, caption="已上传的图片", use_column_width=True) | |
| # 清除图片按钮 | |
| if st.session_state.image_data and st.button("🗑️ 清除图片"): | |
| st.session_state.image_data = None | |
| st.experimental_rerun() | |
| # 清除对话按钮 | |
| if st.button("🧹 清除对话"): | |
| st.session_state.messages = [] | |
| st.session_state.image_data = None | |
| st.experimental_rerun() | |
| with col1: | |
| # 显示聊天历史 | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| # 显示消息内容 | |
| st.markdown(message["content"]) | |
| # 如果消息包含图片 | |
| if "image" in message and message["image"]: | |
| st.image(message["image"], use_column_width=True) | |
| # 显示当前上传的图片预览 | |
| if st.session_state.image_data: | |
| with st.expander("📷 当前图片预览", expanded=True): | |
| # 从base64解码图片以显示预览 | |
| if "base64" in st.session_state.image_data: | |
| image_b64 = st.session_state.image_data.split(",")[1] | |
| image_bytes = base64.b64decode(image_b64) | |
| st.image(image_bytes, caption="即将发送的图片", use_column_width=True) | |
| # 用户输入 | |
| prompt = st.chat_input("输入您的问题...", key="user_input") | |
| # 处理用户输入 | |
| if prompt: | |
| # 添加用户消息到历史 | |
| user_message = {"role": "user", "content": prompt} | |
| if st.session_state.image_data: | |
| user_message["image"] = st.session_state.image_data | |
| st.session_state.messages.append(user_message) | |
| # 显示用户消息 | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| if st.session_state.image_data: | |
| # 从base64解码图片以显示预览 | |
| if "base64" in st.session_state.image_data: | |
| image_b64 = st.session_state.image_data.split(",")[1] | |
| image_bytes = base64.b64decode(image_b64) | |
| st.image(image_bytes, use_column_width=True) | |
| # 显示助手思考中的状态 | |
| with st.chat_message("assistant"): | |
| with st.spinner("思考中..."): | |
| # 获取API响应 | |
| response = get_api_response(prompt, st.session_state.image_data) | |
| # 显示响应 | |
| message_placeholder = st.empty() | |
| full_response = "" | |
| # 模拟流式响应 | |
| for chunk in response.split(): | |
| full_response += chunk + " " | |
| message_placeholder.markdown(full_response + "▌") | |
| time.sleep(0.01) | |
| message_placeholder.markdown(full_response) | |
| # 添加助手响应到历史 | |
| st.session_state.messages.append({"role": "assistant", "content": full_response}) | |
| # 清除当前图片数据,防止重复使用 | |
| st.session_state.image_data = None | |
| # 重新运行页面以更新UI | |
| st.experimental_rerun() | |
| if __name__ == "__main__": | |
| # 从环境变量获取 API 密钥,或者提示用户设置 | |
| api_key = os.environ.get("MISTRAL_API_KEY", "") | |
| if not api_key: | |
| st.sidebar.warning("未设置 MISTRAL_API_KEY 环境变量。请设置环境变量或在代码中直接设置密钥。") | |
| api_key = st.sidebar.text_input("输入您的 Mistral API 密钥:", type="password") | |
| # 设置 API 密钥 | |
| if api_key: | |
| service.headers = {"Authorization": f"Bearer {api_key}"} | |
| st.sidebar.success("API密钥已配置") | |
| else: | |
| st.sidebar.error("请设置 Mistral API 密钥以继续使用") | |