File size: 5,092 Bytes
d1eee25
 
 
 
 
c164738
9d22697
 
 
 
8f660bb
9d22697
d1eee25
 
 
d9b734e
 
 
1a54f9a
 
 
 
 
 
 
 
 
21558a0
1a54f9a
9d22697
 
c7b621c
 
 
 
 
 
 
 
 
 
c164738
c7b621c
3af6820
c164738
9d22697
 
 
 
 
 
 
 
47c2ca7
 
9d22697
 
 
 
 
 
47c2ca7
 
 
9d22697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7583d3
8800100
 
 
 
 
9d22697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47c2ca7
9d22697
 
 
 
 
 
 
 
 
47c2ca7
 
 
 
 
9d22697
 
1a54f9a
9d22697
1a54f9a
 
 
9d22697
1a54f9a
9d22697
 
1a54f9a
d1eee25
1a54f9a
d1eee25
 
 
1a54f9a
d1eee25
 
 
 
1a54f9a
d1eee25
 
 
 
 
 
 
283e6cb
d1eee25
 
b6138bd
d1eee25
1a54f9a
 
7e0654c
d1eee25
 
 
951e86d
d1eee25
7e0654c
 
d1eee25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import os
import openai
from openai import OpenAI
import logging
import fitz  # PyMuPDF
import pdfminer.high_level
import docx
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from concurrent.futures import ThreadPoolExecutor

# Verbose logging for development; emits DEBUG records from every library.
logging.basicConfig(level=logging.DEBUG)
# Copy the deployment secret "gptkey" into the variable the OpenAI SDK reads.
# NOTE(review): raises KeyError at import time if "gptkey" is unset — confirm
# that fail-fast behavior is intended.
os.environ["OPENAI_API_KEY"] = os.environ["gptkey"]

# Shared OpenAI client used by get_embedding() and gpt_answer().
client = OpenAI(
  api_key=os.environ['OPENAI_API_KEY'],  # this is also the default, it can be omitted
)

def generate_answer(brand_name, question, files):
    """Answer *question* about *brand_name* using the uploaded *files*.

    Parameters
    ----------
    brand_name : str
        Brand the question is about (prepended to the retrieval query).
    question : str
        The user's question.
    files : list
        Uploaded file objects (each with a ``.name`` path attribute).

    Returns
    -------
    tuple[str, str]
        ``(answer, info)`` — the model's answer and the retrieved context
        that was fed to it (both shown in the UI).
    """
    # The original ThreadPoolExecutor added no parallelism: result() was
    # awaited before the second task was submitted, and gpt_answer depends
    # on gpt_relevant_info's output anyway. Plain sequential calls are
    # equivalent and clearer.
    info = gpt_relevant_info(brand_name, question, files)
    ans = gpt_answer(brand_name, question, info)
    return ans, info
    
# Parse a PDF file.
def extract_text_from_pdf(file_path):
    """Extract and concatenate the plain text of every page of a PDF.

    Parameters
    ----------
    file_path : str
        Path to the PDF on disk.

    Returns
    -------
    str
        Text of all pages, in page order.
    """
    # Context manager closes the document (the old code leaked the handle);
    # the redundant "read page 0 and print it" debug block is removed, and
    # join replaces the quadratic += loop.
    with fitz.open(file_path) as pdf_document:
        logging.debug("PDF page count: %d", pdf_document.page_count)
        return "".join(page.get_text() for page in pdf_document)

# Parse a DOCX file.
def extract_text_from_docx(file_path):
    """Extract text from a .docx file, one newline-terminated line per paragraph.

    Parameters
    ----------
    file_path : str
        Path to the .docx file on disk.

    Returns
    -------
    str
        Every paragraph's text followed by "\\n" (empty string for an
        empty document), matching the original concatenation exactly.
    """
    doc = docx.Document(file_path)
    # join replaces the quadratic += loop; the stdout dump of the whole
    # document was debug residue and is removed.
    return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

# Parse a TXT file.
def extract_text_from_txt(file_path):
    """Read a UTF-8 text file and return its entire contents.

    Parameters
    ----------
    file_path : str
        Path to the text file on disk.

    Returns
    -------
    str
        Full file contents. Raises ``UnicodeDecodeError`` for non-UTF-8
        files, as before.
    """
    # The print() calls that dumped the whole file to stdout were debug
    # residue and are removed.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()

# Choose the parser matching the file type.
def parse_file(file):
    """Dispatch to the text extractor matching the uploaded file's extension.

    Parameters
    ----------
    file : object
        Uploaded file object exposing the path via ``.name`` (Gradio file).

    Returns
    -------
    str
        Extracted text, or the sentinel string "Unsupported file type"
        (which downstream code embeds as-is) for unknown extensions.
    """
    file_path = file.name
    # Case-insensitive match so ".PDF" / ".Docx" uploads are handled too
    # (the original comparison was case-sensitive and silently rejected them).
    lowered = file_path.lower()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(file_path)
    if lowered.endswith(".docx"):
        return extract_text_from_docx(file_path)
    if lowered.endswith(".txt"):
        return extract_text_from_txt(file_path)
    return "Unsupported file type"

# Get the embedding vector for a text.
def get_embedding(text, model="text-embedding-3-small"):
    """Return the OpenAI embedding vector for *text*.

    Parameters
    ----------
    text : str
        Text to embed.
    model : str
        Embedding model name (default "text-embedding-3-small").

    Returns
    -------
    list[float]
        The embedding vector of the first (only) result.
    """
    # The print() that dumped the full vector to stdout was debug residue
    # and is removed.
    response = client.embeddings.create(input=text, model=model)
    return response.data[0].embedding

# Slice long text into smaller chunks.
def split_text(text, max_length=500):
    """Split *text* on '.' boundaries into chunks of roughly *max_length* chars.

    Each sentence keeps a trailing period (one is re-appended even if the
    original text did not end with '.'). A sentence longer than
    *max_length* becomes a chunk on its own.

    Bug fixed: the original appended an empty-string chunk whenever the
    very first sentence already exceeded *max_length*; empty chunks were
    later sent to the embeddings API.

    Parameters
    ----------
    text : str
        Text to split.
    max_length : int
        Soft upper bound on chunk length, in characters.

    Returns
    -------
    list[str]
        Non-empty chunks in original order.
    """
    chunks = []
    current_chunk = ""
    for sentence in text.split('.'):
        if len(current_chunk) + len(sentence) <= max_length:
            current_chunk += sentence + '.'
        else:
            if current_chunk:  # guard: never emit an empty chunk
                chunks.append(current_chunk)
            current_chunk = sentence + '.'
    if current_chunk:
        chunks.append(current_chunk)
    return chunks

# Rank chunks by similarity and return the most relevant ones.
def find_top_n_relevant_sections(input_text, file_texts, n):
    """Return the *n* chunks across *file_texts* most similar to *input_text*.

    Every document is chunked with ``split_text``, each chunk is embedded
    individually, and cosine similarity against the query embedding ranks
    the chunks in descending order of relevance.
    """
    query_vec = get_embedding(input_text)

    candidate_texts = []
    candidate_vecs = []
    for document in file_texts:
        for chunk in split_text(document):
            candidate_texts.append(chunk)
            candidate_vecs.append(get_embedding(chunk))

    # One row of similarities: query vs every candidate chunk.
    scores = cosine_similarity([query_vec], candidate_vecs)[0]
    # Descending order; slice keeps the n best.
    ranked = scores.argsort()[::-1][:n]
    return [candidate_texts[idx] for idx in ranked]

# Build the retrieval context from the uploaded files.
def gpt_relevant_info(brand_name, question, files):
    """Parse every upload and return the five chunks most relevant to the
    combined ``brand_name + question`` query, joined by blank lines."""
    texts = [parse_file(upload) for upload in files]
    query = brand_name + question
    top_chunks = find_top_n_relevant_sections(query, texts, 5)
    return "\n\n".join(top_chunks)

def gpt_answer(brand_name, question, info):
    """Ask gpt-4o to answer *question* about *brand_name*, grounded on the
    retrieved context *info*.

    The system prompt (Traditional Chinese) sets a business-strategy-analyst
    persona; the user prompt embeds the brand, the context and the question.
    Returns the model's text completion.
    """
    system_msg = {
        "role": "system",
        "content": "請扮演一個具備專業知識的商業策略分析師",
    }
    user_msg = {
        "role": "user",
        "content": f"請根據品牌:{brand_name}與你原本的知識以及參考資料:{info}來回答這個問題{question}」",
    }

    response = client.chat.completions.create(
        model='gpt-4o',
        max_tokens=4096,
        temperature=0,  # deterministic output
        messages=[system_msg, user_msg],
    )
    return response.choices[0].message.content

# Gradio UI: brand-name and question text inputs plus a multi-file upload,
# wired to generate_answer, which returns (answer, retrieved context).
demo = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Text(label="品牌名稱",value="Toyota"),
        gr.Text(label="關注問題",value="電動車展業的發展"),
        gr.File(label="上傳文件", file_count="multiple"),  # multi-file upload support
    ],
    outputs=[
        gr.Text(label="結果解釋",lines=30), 
        gr.Text(label="向量資料",lines=10) 
    ],
    title="回答助手",
    description="根據上傳的檔案進行回答",
    allow_flagging="never", )
# share=True exposes a temporary public Gradio tunnel URL in addition to
# the local server.
demo.launch(share=True)