# Spaces status: Runtime error
| import torch | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import FAISS | |
| from transformers import AutoTokenizer, AutoModel | |
| from duckduckgo_search import ddg | |
| import time | |
| import gradio as gr | |
| import gc | |
def best_device():
    """Return the name of the preferred torch device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        str: one of ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        return 'cuda'
    # Not every PyTorch build exposes ``torch.backends.mps``; look it up
    # defensively so that such builds fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'
# Resolve the compute backend once; the embedding model and the chat model
# below are both placed according to this value.
device = best_device()

# Sentence-embedding model used to query the local FAISS index.
embeddings = HuggingFaceEmbeddings(
    model_name='GanymedeNil/text2vec-large-chinese',
    model_kwargs={'device': device},
)
local_db = FAISS.load_local('./text2vec/store', embeddings)

# Chat model and its tokenizer (both need the repository's remote code).
model_name = 'THUDM/chatglm-6b-int4'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the weights once, then apply the precision/placement that matches
# the selected device: half precision on CUDA and MPS, float32 on CPU.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
if device == 'cuda':
    model = model.half().cuda()
elif device == 'mps':
    model = model.half().to("mps")
else:
    model = model.float()
model = model.eval()
def local_query(text, top_k=3):
    """Look up the passages in the local vector store closest to ``text``.

    Args:
        text: query string to search for.
        top_k: maximum number of passages to return.

    Returns:
        str: the selected passages with all spaces removed, each followed
        by a newline; an empty string when ``top_k`` is not positive or
        nothing is found.
    """
    if top_k <= 0:
        return ''
    # Ask the store for exactly ``top_k`` hits. Without ``k`` the store's
    # default result count silently capped how many passages a larger
    # ``top_k`` could ever yield.
    docs_and_scores = local_db.similarity_search_with_score(text, k=top_k)
    # Ascending score order (presumably a distance, i.e. lower = closer --
    # confirm against the index configuration).
    docs_and_scores.sort(key=lambda pair: pair[1])
    return ''.join(
        doc.page_content.replace(' ', '') + '\n'
        for doc, _score in docs_and_scores[:top_k]
    )
def web_search(text, limit=3):
    """Search the web for ``text`` and collect the first result snippets.

    This is a best-effort lookup: any failure is logged and whatever was
    collected up to that point is returned.

    Args:
        text: search keywords.
        limit: maximum number of results whose ``'body'`` is collected.

    Returns:
        str: the collected result bodies, each followed by a newline; an
        empty string when the search returns nothing or fails immediately.
    """
    snippets = []
    try:
        results = ddg(text)
        if results:
            for index, result in enumerate(results):
                if index >= limit:
                    # Enough snippets collected; skip the remaining results.
                    break
                snippets.append(result['body'] + "\n")
    except Exception as exc:
        # Deliberately broad: web search is optional. Include the exception
        # in the log line so that failures can actually be diagnosed.
        print(f"网络检索异常:{text} ({exc!r})")
    return ''.join(snippets)
def _build_prompt(question, local_content='', web_content=''):
    """Assemble the single prompt that is sent to the chat model."""
    # Web snippets come first, then local passages, joined by a newline.
    known_parts = [part for part in (web_content, local_content) if part]
    if not known_parts:
        return f'简洁和专业的来回答我的问题。\n如果你不知道答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n我的问题是:\n{question}'
    known_info = '\n'.join(known_parts)
    return f'基于以下已知信息,简洁和专业的来回答我的问题。\n如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n已知信息是:\n{known_info}\n我的问题是:\n{question}'


def ask_question(question, local_content='', web_content=''):
    """Ask the chat model ``question``, optionally grounded in known context.

    Exactly one instruction preamble is used. Previously the no-context
    prompt was built first and then embedded as the "question" of the
    context prompt, so the model received two nested sets of instructions
    whenever any context was supplied.

    Args:
        question: the user's raw question.
        local_content: passages from the local store ('' for none).
        web_content: snippets from web search ('' for none).

    Returns:
        The response returned by ``model.chat``.
    """
    prompt = _build_prompt(question, local_content, web_content)
    # Each call starts from an empty history; the updated history that
    # ``model.chat`` returns is not needed.
    response, _history = model.chat(
        tokenizer,
        prompt,
        history=[],
        max_length=10000,
        temperature=0.1,
    )
    return response
def on_click(question, kb_types):
    """Handle a click on the test button.

    Gathers the context sources selected in ``kb_types``, asks the model,
    logs the question and the answer to stdout, and returns the answer.

    Args:
        question: text entered by the user.
        kb_types: collection of the selected knowledge-source labels.

    Returns:
        The answer produced by ``ask_question``.
    """

    def release_memory():
        # Run the garbage collector and, on CUDA, drop cached GPU blocks.
        gc.collect()
        if best_device() == 'cuda':
            torch.cuda.empty_cache()

    def now():
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    release_memory()
    print("问题 [" + now() + "]: \n", question + "\n\n")

    local_content = local_query(question, 2) if '结合本地数据' in kb_types else ''
    web_content = web_search(question, 3) if '结合网络检索' in kb_types else ''
    result = ask_question(question, local_content, web_content)

    # Log label keyed by (local context present, web context present).
    source_labels = {
        (True, True): '结合本地数据和网络检索',
        (True, False): '结合本地数据',
        (False, True): '结合网络检索',
        (False, False): '仅用模型数据',
    }
    label = source_labels[(len(local_content) > 0, len(web_content) > 0)]
    print(label + ' [' + now() + ']: ')
    print(f'{result}\n\n----------------------------')

    release_memory()
    return result
# Build the Gradio page: a title, the knowledge-source selector, the
# question box, the submit button and the answer box.
with gr.Blocks() as block:
    gr.Markdown('<center><h1>LLM问答机器人测试</h1></center>')
    cg_type = gr.CheckboxGroup(
        ['结合本地数据', '结合网络检索'],
        label='知识库类型(不勾选则仅用模型数据):',
    )
    tb_input = gr.Textbox(label='输入问题(本地数据只有中国历史知识):')
    btn = gr.Button("测试", variant='primary')
    tb_output = gr.Textbox(label='AI回答:')
    # Route the question and the selected sources through on_click and show
    # its return value in the answer box.
    btn.click(
        fn=on_click,
        inputs=[tb_input, cg_type],
        outputs=tb_output,
    )

# NOTE(review): the original indentation was lost, so whether queue() and
# launch() sat inside the ``with`` body cannot be told from the source; they
# are placed after it here -- confirm.
block.queue(concurrency_count=1)
block.launch()