# NOTE: hosting-page residue ("Spaces: Sleeping") captured with the file; not part of the program.
| import os | |
| import json | |
| import gradio as gr | |
| import requests | |
| from dotenv import load_dotenv | |
| import logging | |
# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables (python-dotenv reads a local .env file if present).
load_dotenv()

# API configuration.
# Surrounding whitespace is stripped first: a secret pasted with a trailing
# newline or space would otherwise pass validation and then end up inside the
# request URL or the Authorization header.
BASE_URL = os.getenv("API_URL", "").strip().rstrip("/")  # drop trailing slashes
API_KEY = os.getenv("API_KEY", "").strip()

# Validate the environment before deriving anything from it.
if not BASE_URL or not API_KEY:
    raise ValueError("""
    请确保设置了必要的环境变量:
    - API_URL: API基础地址 (例如: https://api.example.com)
    - API_KEY: API密钥
    可以在Hugging Face Space的Settings -> Repository Secrets中设置这些变量
    """)

# Full URL of the chat-completions endpoint.
API_URL = f"{BASE_URL}/v1/chat/completions"
class ChatBot:
    """Hold the HTTP headers used to call the chat-completions API.

    The file previously defined this class twice; the later definition
    replaced the earlier one, which silently discarded ``verify_api_config``.
    This single definition keeps both the headers and the probe method.

    The constructor performs no network I/O (matching the definition that was
    actually in effect); call ``verify_api_config()`` explicitly to probe the
    endpoint.
    """

    def __init__(self):
        self.headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream",  # ask for a server-sent-event stream
        }

    def _loggable_headers(self):
        """Return a copy of ``self.headers`` with the Authorization value masked."""
        safe = dict(self.headers)
        if "Authorization" in safe:
            safe["Authorization"] = "Bearer ***"
        return safe

    def verify_api_config(self):
        """Probe the endpoint with an OPTIONS request and log the outcome.

        Best-effort diagnostic: every failure is logged and swallowed so that
        a failed probe never prevents the application from starting.
        """
        try:
            response = requests.options(API_URL, timeout=5)
            logger.info(f"API endpoint: {API_URL}")
            # Never write the bearer token to the logs.
            logger.info(f"API headers: {self._loggable_headers()}")
            if response.status_code >= 400:
                logger.error(f"API配置可能有误: {response.status_code}")
                logger.error(f"API响应: {response.text[:200]}")
        except Exception as e:
            logger.error(f"API连接测试失败: {str(e)}")


def format_message(role, content):
    """Build one chat message dict in the ``{"role", "content"}`` shape."""
    return {"role": role, "content": content}
def chat_stream(message, history):
    """Stream the assistant's reply to ``message`` as a growing chat history.

    This is a generator, so every state the UI should display must be
    *yielded*; a value passed to ``return`` is never delivered to the caller.

    Args:
        message: The new user message.
        history: Previous turns as ``(user, assistant)`` pairs. May be empty
            or ``None``. The caller's list is not mutated.

    Yields:
        A new history list whose last entry is ``(message, reply_so_far)``.
        Failures are reported the same way, with the error text in the
        assistant slot.
    """
    chatbot = ChatBot()
    logger.info(f"Sending message: {message}")
    logger.info(f"API URL: {API_URL}")
    # Log header names only: the Authorization value contains the API key.
    logger.info(f"Headers: {sorted(chatbot.headers)}")

    # Work on a copy so each yielded state is an independent list.
    history = list(history or [])

    messages = []
    for human, assistant in history:
        # A turn may have an empty side; skip it rather than send null content.
        if human is not None:
            messages.append(format_message("user", human))
        if assistant is not None:
            messages.append(format_message("assistant", assistant))
    messages.append(format_message("user", message))

    partial_message = ""
    try:
        payload = {
            "model": "gpt-4o",
            "messages": messages,
            "stream": True,
            "temperature": 0.7,
        }
        logger.info(f"发送请求数据: {json.dumps(payload, ensure_ascii=False)}")

        # (connect, read) timeouts keep a dead endpoint from hanging the UI;
        # the context manager releases the connection on every exit path.
        with requests.post(
            API_URL,
            headers=chatbot.headers,
            json=payload,
            stream=True,
            timeout=(10, 120),
        ) as response:
            if response.headers.get("content-type", "").startswith("text/html"):
                error_msg = "API返回了HTML而不是预期的流式响应。请检查API配置。"
                logger.error(error_msg)
                yield history + [(message, error_msg)]
                return

            if response.status_code != 200:
                error_msg = f"API返回错误状态码: {response.status_code}\n错误信息: {response.text}"
                logger.error(error_msg)
                yield history + [(message, error_msg)]
                return

            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8", errors="replace")
                logger.info(f"收到数据: {line}")

                if line.startswith(":"):
                    # Server-sent-event comment / keep-alive line.
                    continue
                if line.startswith("data:"):
                    # The space after "data:" is optional in the SSE format.
                    line = line[5:].lstrip()
                if line == "[DONE]":
                    break

                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error(f"JSON解析错误: {e}")
                    continue

                choices = chunk.get("choices") if isinstance(chunk, dict) else None
                if not choices:
                    # Some chunks carry no choices (e.g. metadata-only chunks).
                    continue
                delta = choices[0].get("delta") or {}
                content = delta.get("content")
                if content:
                    partial_message += content
                    yield history + [(message, partial_message)]

        if not partial_message:
            error_msg = "未能获取到有效的响应内容"
            logger.error(error_msg)
            yield history + [(message, error_msg)]
    except Exception as e:
        error_msg = f"请求发生错误: {str(e)}"
        logger.error(error_msg)
        # Keep whatever text already arrived so a mid-stream failure does not
        # erase the partial reply.
        shown = f"{partial_message}\n\n{error_msg}" if partial_message else error_msg
        yield history + [(message, shown)]
# Gradio UI layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        avatar_images=["assets/user.png", "assets/assistant.png"],
    )
    msg = gr.Textbox(
        placeholder="在这里输入您的问题...",
        container=False,
        scale=7,
    )
    with gr.Row():
        submit = gr.Button("发送", scale=2, variant="primary")
        clear = gr.Button("清除对话", scale=1)

    # Event wiring: pressing Enter in the textbox and clicking the send button
    # both stream a reply into the chat window and then empty the textbox.
    # The lambdas are kept as-is so the default endpoint names do not change.
    enter_event = msg.submit(
        fn=chat_stream,
        inputs=[msg, chatbot],
        outputs=[chatbot],
        api_name="chat",
    )
    enter_event.then(
        fn=lambda: "",
        inputs=None,
        outputs=[msg],
        api_name="clear_input",
    )

    click_event = submit.click(
        fn=chat_stream,
        inputs=[msg, chatbot],
        outputs=[chatbot],
    )
    click_event.then(
        fn=lambda: "",
        inputs=None,
        outputs=[msg],
    )

    # Reset the conversation to an empty history.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

# Launch the app when run as a script.
if __name__ == "__main__":
    queued_app = demo.queue()
    queued_app.launch(debug=True)