# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command-line interface for accessing text models.

Serves a Gradio chat UI that first streams an answer from a base model and
then streams a corrected answer from an Aligner model, both reached through
OpenAI-compatible HTTP endpoints.
"""

import argparse
import os
import random

import gradio as gr
from openai import OpenAI

random.seed(42)

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# System prompt; edit as needed.
SYSTEM_PROMPT = "你是一个有帮助的AI助手,能够回答用户的问题并提供帮助。"

# Connection settings.
openai_api_key = "jiayi"  # Not meaningful; only used to initialise the clients.
# Adjust the ports to match the actual model API servers.
aligner_port = 8013
base_port = 8011
aligner_api_base = f"http://0.0.0.0:{aligner_port}/v1"
base_api_base = f"http://0.0.0.0:{base_port}/v1"

# NOTE: please set these to the model names/paths served by each endpoint.
aligner_model = ""
base_model = ""

aligner_client = OpenAI(
    api_key=openai_api_key,
    base_url=aligner_api_base,
)

base_client = OpenAI(
    api_key=openai_api_key,
    base_url=base_api_base,
)

# Example questions shown in the UI.
TEXT_EXAMPLES = [
    "介绍一下北京大学的历史",
    "解释一下什么是深度学习",
    "写一首关于春天的诗",
]


def text_conversation(text: str, role: str = 'user'):
    """Return a one-element message list ``[{'role': role, 'content': text}]``."""
    return [{'role': role, 'content': text}]


def _history_to_messages(history):
    """Convert chat history into a flat list of role/content messages.

    Accepts ``None``, a sequence of ``(user_msg, bot_msg)`` pairs, or a
    sequence of ``{'role': ..., 'content': ...}`` dicts. Empty messages are
    skipped; dict entries are kept only when the role is ``user`` or
    ``assistant`` and the content is a non-empty string.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get('role')
            content = turn.get('content')
            if role in ('user', 'assistant') and isinstance(content, str) and content:
                messages.extend(text_conversation(content, role))
            continue
        past_user_msg, past_bot_msg = turn
        if past_user_msg:
            messages.extend(text_conversation(past_user_msg, 'user'))
        if past_bot_msg:
            messages.extend(text_conversation(past_bot_msg, 'assistant'))
    return messages


def _iter_stream_text(stream):
    """Yield each non-``None`` text delta from a streamed chat completion."""
    for chunk in stream:
        # Some OpenAI-compatible servers emit chunks with an empty ``choices``
        # list (e.g. a trailing usage-only chunk); indexing [0] would raise.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is not None:
            yield piece


def question_answering(message: str, history: list):
    """Answer ``message`` with the base model, then stream the Aligner correction.

    This is a generator: every yielded string is the full text to display so
    far (not an increment). The base answer is rendered inside a ``bash``
    fenced block; the Aligner output is appended after it.

    Args:
        message: The current user question.
        history: Previous turns, as ``(user, bot)`` pairs or role/content dicts.

    Yields:
        The progressively growing display text.
    """
    conversation = text_conversation(SYSTEM_PROMPT, 'system')
    conversation.extend(_history_to_messages(history))
    conversation.extend(text_conversation(message))

    # Query the base model with streaming enabled.
    stream = base_client.chat.completions.create(
        model=base_model,
        stream=True,
        messages=conversation,
    )

    base_section = "🌟 **原始回答:**\n"
    yield base_section

    # base_answer holds only the raw model text; it is the Aligner's input.
    base_answer = ""
    for piece in _iter_stream_text(stream):
        base_answer += piece
        yield f"```bash\n{base_section}{base_answer}\n```"

    # Close the base-answer block and open the Aligner section.
    aligner_section = "\n**Aligner 修正中...**\n\n🌟 **修正后回答:**\n"
    display_prefix = f"```bash\n{base_section}{base_answer}\n```{aligner_section}"
    yield display_prefix

    aligner_conversation = text_conversation(SYSTEM_PROMPT, 'system')
    aligner_current_question = (
        f'##Question: {message}\n##Answer: {base_answer}\n##Correction: '
    )
    aligner_conversation.extend(text_conversation(aligner_current_question))

    aligner_stream = aligner_client.chat.completions.create(
        model=aligner_model,
        stream=True,
        messages=aligner_conversation,
    )

    aligner_answer = ""
    for piece in _iter_stream_text(aligner_stream):
        aligner_answer += piece
        # Strip on the accumulated text so a marker split across chunks is
        # still removed once it is complete.
        aligner_answer = aligner_answer.replace('##CORRECTION:', '')
        yield f"{display_prefix}{aligner_answer}"


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860, help="Gradio服务端口")
    # Sharing stays on by default; a real boolean default plus --no-share makes
    # it possible to turn it off (the old string default was always truthy).
    parser.add_argument(
        "--share", dest="share", action="store_true", default=True,
        help="是否创建公共链接",
    )
    parser.add_argument(
        "--no-share", dest="share", action="store_false",
        help="不创建公共链接",
    )
    # NOTE: --api-only is parsed for CLI compatibility but is not acted on.
    parser.add_argument(
        "--api-only", action="store_true", default=False,
        help="只输出Python API调用示例",
    )
    args = parser.parse_args()

    # Build the Gradio chat interface (the generator fn enables streaming).
    iface = gr.ChatInterface(
        fn=question_answering,
        title='Aligner',
        description='网络安全 Aligner',
        examples=TEXT_EXAMPLES,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
        ),
    )

    iface.launch(server_port=args.port, share=args.share)