Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import openai | |
| from openai import OpenAI | |
| import logging | |
| import fitz # PyMuPDF | |
| import pdfminer.high_level | |
| import docx | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from concurrent.futures import ThreadPoolExecutor | |
logging.basicConfig(level=logging.DEBUG)

# Copy the Space secret "gptkey" into the conventional OPENAI_API_KEY slot
# so any OpenAI tooling that reads the default env var also works.
# NOTE(review): raises KeyError at import time if "gptkey" is unset —
# looks like intentional fail-fast, but confirm.
os.environ["OPENAI_API_KEY"] = os.environ["gptkey"]

# Explicit api_key (this is also the default and could be omitted).
client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
)
def generate_answer(brand_name, question, files):
    """Answer *question* about *brand_name*, grounded in uploaded *files*.

    Args:
        brand_name: Brand the question concerns.
        question: The user's question text.
        files: Uploaded file objects to retrieve context from.

    Returns:
        (ans, info): the model's answer and the retrieved reference text.
    """
    # The two steps are strictly sequential — the answer generation needs
    # the retrieved info — so the original ThreadPoolExecutor (submit then
    # immediately .result()) added no concurrency. Call the helpers directly.
    info = gpt_relevant_info(brand_name, question, files)
    ans = gpt_answer(brand_name, question, info)
    return ans, info
# Parse a PDF file.
def extract_text_from_pdf(file_path):
    """Extract the full text of a PDF using PyMuPDF.

    Args:
        file_path: Path to the .pdf file.

    Returns:
        The concatenated text of all pages.
    """
    # Context manager guarantees the document handle is closed (the
    # original leaked it), and the redundant debug read of page 0 plus
    # the quadratic string += build are dropped.
    page_texts = []
    with fitz.open(file_path) as pdf_document:
        for page in pdf_document:
            page_texts.append(page.get_text())
    return "".join(page_texts)
# Parse a DOCX file.
def extract_text_from_docx(file_path):
    """Extract text from a .docx file, one paragraph per line.

    Args:
        file_path: Path to the .docx file.

    Returns:
        All paragraph texts, each followed by a newline.
    """
    doc = docx.Document(file_path)
    # join builds the string in one pass instead of quadratic +=;
    # the original's debug prints are removed.
    return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
# Parse a TXT file.
def extract_text_from_txt(file_path):
    """Read a UTF-8 text file and return its entire contents.

    Args:
        file_path: Path to the .txt file.

    Returns:
        The file's contents as a string.
    """
    # Debug prints removed; the with-block closes the handle either way.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
# Choose the parser matching the file type.
def parse_file(file):
    """Dispatch to the right text extractor based on file extension.

    Args:
        file: An uploaded-file object exposing a ``.name`` path attribute
              (as provided by gradio's File component).

    Returns:
        The extracted text, or the literal string "Unsupported file type"
        for extensions other than .pdf / .docx / .txt.
    """
    file_path = file.name
    # Lower-case the suffix so ".PDF" / ".Docx" uploads are handled too
    # (the original's endswith checks were case-sensitive).
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix == ".pdf":
        return extract_text_from_pdf(file_path)
    if suffix == ".docx":
        return extract_text_from_docx(file_path)
    if suffix == ".txt":
        return extract_text_from_txt(file_path)
    return "Unsupported file type"
# Get the embedding vector for a text.
def get_embedding(text, model="text-embedding-3-small"):
    """Return the embedding vector for *text* via the OpenAI API.

    Args:
        text: The text to embed.
        model: Embedding model name.

    Returns:
        The embedding as a list of floats.
    """
    # Debug prints of the full vector removed — they flooded the log.
    response = client.embeddings.create(input=text, model=model)
    return response.data[0].embedding
# Slice long text into smaller chunks.
def split_text(text, max_length=500):
    """Split *text* at '.' boundaries into chunks of ~*max_length* chars.

    Each sentence gets its trailing '.' restored. A single sentence longer
    than *max_length* becomes its own (oversized) chunk.

    Args:
        text: Source text.
        max_length: Soft upper bound on chunk length.

    Returns:
        List of non-empty chunk strings.
    """
    chunks = []
    current_chunk = ""
    for sentence in text.split('.'):
        if len(current_chunk) + len(sentence) <= max_length:
            current_chunk += sentence + '.'
        else:
            # Fix: the original appended current_chunk even when it was
            # empty, producing a spurious leading "" chunk whenever the
            # first sentence alone exceeded max_length.
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence + '.'
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
# Score similarity and return the most relevant chunks.
def find_top_n_relevant_sections(input_text, file_texts, n):
    """Return the *n* chunks most similar to *input_text*.

    Every file's text is split into chunks, each chunk is embedded, and
    cosine similarity against the query embedding ranks them.

    Args:
        input_text: The query text.
        file_texts: Iterable of document texts.
        n: Number of top chunks to return.

    Returns:
        The n highest-scoring chunk strings, best first.
    """
    query_vec = get_embedding(input_text)
    chunk_texts = []
    for file_text in file_texts:
        chunk_texts.extend(split_text(file_text))
    chunk_vecs = [get_embedding(chunk) for chunk in chunk_texts]
    scores = cosine_similarity([query_vec], chunk_vecs)[0]
    ranked = scores.argsort()[-n:][::-1]
    return [chunk_texts[i] for i in ranked]
# Handle the uploaded files and build the reference context.
def gpt_relevant_info(brand_name, question, files):
    """Retrieve the five passages most relevant to brand + question.

    Args:
        brand_name: Brand the question concerns.
        question: The user's question text.
        files: Uploaded file objects to parse and search.

    Returns:
        The top passages joined by blank lines.
    """
    texts = [parse_file(uploaded) for uploaded in files]
    query = brand_name + question
    top_sections = find_top_n_relevant_sections(query, texts, 5)
    return "\n\n".join(top_sections)
def gpt_answer(brand_name, question, info):
    """Ask GPT-4o to answer *question* about *brand_name* using *info*.

    Args:
        brand_name: Brand the question concerns.
        question: The user's question text.
        info: Retrieved reference passages to ground the answer.

    Returns:
        The model's answer text.
    """
    # System role: act as a professional business-strategy analyst;
    # user prompt injects brand, retrieved references, and the question.
    # (Dead commented-out retry loop from the original removed.)
    messages = [
        {"role": "system", "content": "請扮演一個具備專業知識的商業策略分析師"},
        {"role": "user", "content": f"請根據品牌:{brand_name}與你原本的知識以及參考資料:{info}來回答這個問題{question}」"},
    ]
    response = client.chat.completions.create(
        model='gpt-4o',
        max_tokens=4096,
        temperature=0,  # deterministic output
        messages=messages,
    )
    return response.choices[0].message.content
# Web UI: brand + question text inputs plus multi-file upload,
# answer and retrieved-context text outputs.
demo = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Text(label="品牌名稱", value="Toyota"),
        gr.Text(label="關注問題", value="電動車展業的發展"),
        gr.File(label="上傳文件", file_count="multiple"),  # multi-file upload
    ],
    outputs=[
        gr.Text(label="結果解釋", lines=30),
        gr.Text(label="向量資料", lines=10),
    ],
    title="回答助手",
    description="根據上傳的檔案進行回答",
    allow_flagging="never",
)
demo.launch(share=True)