| |
| |
| import gradio as gr |
| from openai_utils import get_completion |
| from prompt_utils import build_prompt |
| from vectordb_utils import InMemoryVecDB |
| from pdf_utils import extract_text_from_pdf |
| from text_utils import split_text |
|
|
# Module-level vector store shared by the callbacks below: init_db adds
# documents to it and chat queries it.
vec_db = InMemoryVecDB()
|
|
| |
def init_db(file):
    """Extract the text of an uploaded PDF and add it to the shared store.

    Documents are added to the module-level ``vec_db``; nothing in this
    function clears it first, so successive uploads accumulate.

    Args:
        file: Value delivered by the upload event. Either an object that
            exposes the file path as ``.name`` or a plain path string.
            ``None`` (no file) is ignored.
    """
    if file is None:
        return

    # Accept both a temp-file wrapper (path in ``.name``) and a bare path.
    pdf_path = getattr(file, "name", file)
    paragraphs = extract_text_from_pdf(pdf_path)
    # 500 and 100 are forwarded to split_text unchanged; presumably chunk
    # size and overlap -- confirm against text_utils.
    documents = split_text(paragraphs, 500, 100)
    vec_db.add_documents(documents)
|
|
| |
def chat(user_input, chatbot, context, search_field):
    """Handle one user turn: retrieve passages, query the model, update state.

    Args:
        user_input: Text submitted by the user.
        chatbot: List of ``(user_text, reply)`` pairs; appended to in place.
        context: List of ``{'role', 'content'}`` message dicts; appended to
            in place.
        search_field: Current text of the retrieval-results box. It is
            returned unchanged for blank input and replaced otherwise.

    Returns:
        A 4-tuple ``("", chatbot, context, search_field)``. The empty string
        clears the input box; ``search_field`` holds the retrieved passages
        joined by blank lines.
    """
    # Skip blank submissions: no search, no model call, no empty chat turn.
    if not user_input or not user_input.strip():
        return "", chatbot, context, search_field

    search_results = vec_db.search(user_input, 3)
    search_field = "\n\n".join(search_results)
    prompt = build_prompt(info=search_results, query=user_input)
    # The model is called before the current turn is recorded, so `context`
    # holds only the earlier turns at this point.
    response = get_completion(prompt, context)

    chatbot.append((user_input, response))
    context.append({'role': 'user', 'content': user_input})
    context.append({'role': 'assistant', 'content': response})
    return "", chatbot, context, search_field
|
|
| |
def reset_state():
    """Return blank values for the clear button's four outputs.

    The order matches the ``outputs`` list wired to the clear button:
    chat history, message context, input box text, search box text.
    """
    cleared_history = []
    cleared_context = []
    blank_text = ""
    return cleared_history, cleared_context, blank_text, blank_text
|
|
|
|
def main():
    """Build the Gradio interface, wire its callbacks and launch the server.

    Layout: a PDF upload row, then a row with the chat window on the left
    and, on the right, the retrieval-results box, the input box and the
    submit / clear buttons.
    """
    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">PDF内容检索器</h1>""")
        gr.Markdown("本demo为您提供了一个简单的界面,用于从PDF文件中提取文本,并使用OpenAI的大模型来检索相关的信息。")

        with gr.Row():
            with gr.Column():
                # Extension filters need a leading dot; the original
                # ',pdf' did not match PDF files.
                file_ctrl = gr.File(label="上传文件", file_types=['.pdf'])

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot()
            with gr.Column(scale=2):
                search_field = gr.Textbox(show_label=True, label="检索结果", info="本检索采用的是单一RAG方式", placeholder="空...", lines=10)
                user_input = gr.Textbox(show_label=True, label="用户输入", placeholder="请点击此处输入...", lines=3)
                with gr.Row():
                    submit_btn = gr.Button("提 交", variant="primary", size="lg")
                    empty_btn = gr.Button("清 空", size="sm", variant="secondary")

        # Conversation history in role/content form, passed to chat().
        context = gr.State([])

        # chat() takes and returns the same four components, in this order.
        submit_btn.click(
            chat,
            [user_input, chatbot, context, search_field],
            [user_input, chatbot, context, search_field],
        )

        # The output order must match the tuple returned by reset_state().
        empty_btn.click(reset_state, outputs=[chatbot, context, user_input, search_field])

        # Index the PDF once the upload finishes.
        file_ctrl.upload(init_db, inputs=[file_ctrl])

    # NOTE(review): server_name='0.0.0.0' together with share=True makes the
    # demo reachable from outside this machine -- confirm this is intended.
    demo.queue().launch(share=True, server_name='0.0.0.0', server_port=8080, inbrowser=True)
|
|
|
|
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|