# LiveRAG_QA / app.py
# Hugging Face Space file (author: Zaious; commit 3af6820, verified)
import gradio as gr
import os
import openai
from openai import OpenAI
import logging
import fitz # PyMuPDF
import pdfminer.high_level
import docx
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from concurrent.futures import ThreadPoolExecutor
logging.basicConfig(level=logging.DEBUG)
os.environ["OPENAI_API_KEY"] = os.environ["gptkey"]
client = OpenAI(
api_key=os.environ['OPENAI_API_KEY'], # this is also the default, it can be omitted
)
def generate_answer(brand_name,question, files):
with ThreadPoolExecutor(max_workers=3) as executor:
fun_1 = executor.submit(gpt_relevant_info,brand_name,question, files)
info = fun_1.result()
fun_2 = executor.submit(gpt_answer,brand_name,question, info)
ans = fun_2.result()
return ans,info
# 函数:解析PDF文件
def extract_text_from_pdf(file_path):
pdf_document = fitz.open(file_path)
total_pages = pdf_document.page_count
print(f"总页数: {total_pages}")
# 读取文本
page = pdf_document.load_page(0) # 读取第一页
text = page.get_text("text")
print(f"第一页文本:\n{text}")
text = ""
for page_num in range(len(pdf_document)):
page = pdf_document.load_page(page_num)
text += page.get_text()
return text
# 函数:解析DOCX文件
def extract_text_from_docx(file_path):
doc = docx.Document(file_path)
text = ""
for paragraph in doc.paragraphs:
text += paragraph.text + "\n"
print("=======ppb=======")
print(text)
return text
# 函数:解析TXT文件
def extract_text_from_txt(file_path):
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
print("=======ppc=======")
print(text)
return text
# 函数:根据文件类型选择解析函数
def parse_file(file):
file_path = file.name
if file_path.endswith(".pdf"):
return extract_text_from_pdf(file_path)
elif file_path.endswith(".docx"):
return extract_text_from_docx(file_path)
elif file_path.endswith(".txt"):
return extract_text_from_txt(file_path)
else:
return "Unsupported file type"
# 函数:获取文本嵌入向量
def get_embedding(text, model="text-embedding-3-small"):
response = client.embeddings.create(input=text, model=model)
em = response.data[0].embedding
print("======emm=====")
print(em)
return em
# 函数:将长文本切片为较小段落
def split_text(text, max_length=500):
sentences = text.split('.')
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) <= max_length:
current_chunk += sentence + '.'
else:
chunks.append(current_chunk)
current_chunk = sentence + '.'
if current_chunk:
chunks.append(current_chunk)
return chunks
# 函数:计算相似度并返回最相关的片段
def find_top_n_relevant_sections(input_text, file_texts , n):
input_embedding = get_embedding(input_text)
all_embeddings = []
all_texts = []
for text in file_texts:
chunks = split_text(text)
all_texts.extend(chunks)
all_embeddings.extend([get_embedding(chunk) for chunk in chunks])
similarities = cosine_similarity([input_embedding], all_embeddings)[0]
top_n_indices = similarities.argsort()[-n:][::-1]
top_n_texts = [all_texts[i] for i in top_n_indices]
return top_n_texts
# 定义处理上传文件和回答的函数
def gpt_relevant_info(brand_name,question, files):
file_contents = [parse_file(file) for file in files]
ask = brand_name + question
most_relevant_texts = find_top_n_relevant_sections(ask, file_contents,5)
response = "\n\n".join(most_relevant_texts)
return response
def gpt_answer(brand_name,question, info):
messages_base = [
{"role": "system", "content": "請扮演一個具備專業知識的商業策略分析師"}
]
# Creating a prompt with a structured format for the Persona in Traditional Chinese
prompt_text = f"請根據品牌:{brand_name}與你原本的知識以及參考資料:{info}來回答這個問題{question}」"
messages_base.append({"role": "user", "content": prompt_text})
#for _ in range(loop):
response = client.chat.completions.create(
model='gpt-4o',
max_tokens=4096,
temperature=0,
messages=messages_base
)
completed_text = response.choices[0].message.content
return completed_text
demo = gr.Interface(
fn=generate_answer,
inputs=[
gr.Text(label="品牌名稱",value="Toyota"),
gr.Text(label="關注問題",value="電動車展業的發展"),
gr.File(label="上傳文件", file_count="multiple"), # 增加文件上傳功能
],
outputs=[
gr.Text(label="結果解釋",lines=30),
gr.Text(label="向量資料",lines=10)
],
title="回答助手",
description="根據上傳的檔案進行回答",
allow_flagging="never", )
demo.launch(share=True)