Spaces:
Runtime error
Runtime error
Commit
·
e4c798e
1
Parent(s):
abab449
feat: add summerizer map-reduce
Browse files- utils/chatbot_diff.py +249 -0
- utils/gpt_processor.py +62 -21
utils/chatbot_diff.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import secrets
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import openai
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from google.oauth2.service_account import Credentials
|
| 12 |
+
from googleapiclient.discovery import build
|
| 13 |
+
from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
|
| 14 |
+
from openai.embeddings_utils import distances_from_embeddings
|
| 15 |
+
|
| 16 |
+
from .gpt_processor import QuestionAnswerer
|
| 17 |
+
from .work_flow_controller import WorkFlowController
|
| 18 |
+
|
| 19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 20 |
+
openai.api_key = OPENAI_API_KEY
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Chatbot:
|
| 24 |
+
def __init__(self):
|
| 25 |
+
self.history = []
|
| 26 |
+
self.upload_state = "waiting"
|
| 27 |
+
self.uid = self.__generate_uid()
|
| 28 |
+
|
| 29 |
+
self.g_drive_service = self.__init_drive_service()
|
| 30 |
+
self.knowledge_base = None
|
| 31 |
+
self.context = None
|
| 32 |
+
self.context_page_num = None
|
| 33 |
+
self.context_file_name = None
|
| 34 |
+
|
| 35 |
+
def build_knowledge_base(self, files, upload_mode="once"):
|
| 36 |
+
work_flow_controller = WorkFlowController(files, self.uid)
|
| 37 |
+
self.csv_result_path = work_flow_controller.csv_result_path
|
| 38 |
+
self.json_result_path = work_flow_controller.json_result_path
|
| 39 |
+
|
| 40 |
+
if upload_mode == "Upload to Database":
|
| 41 |
+
self.__get_db_knowledge_base()
|
| 42 |
+
else:
|
| 43 |
+
self.__get_local_knowledge_base()
|
| 44 |
+
|
| 45 |
+
def __get_db_knowledge_base(self):
|
| 46 |
+
filename = "knowledge_base.csv"
|
| 47 |
+
db = self.__read_db(self.g_drive_service)
|
| 48 |
+
cur_content = pd.read_csv(self.csv_result_path)
|
| 49 |
+
for _ in range(10):
|
| 50 |
+
try:
|
| 51 |
+
self.__write_into_db(self.g_drive_service, db, cur_content)
|
| 52 |
+
break
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logging.error(e)
|
| 55 |
+
logging.error("Failed to upload to database, retrying...")
|
| 56 |
+
continue
|
| 57 |
+
self.knowledge_base = db
|
| 58 |
+
self.upload_state = "done"
|
| 59 |
+
|
| 60 |
+
def __get_local_knowledge_base(self):
|
| 61 |
+
with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
|
| 62 |
+
knowledge_base = pd.read_csv(fp)
|
| 63 |
+
knowledge_base["page_embedding"] = (
|
| 64 |
+
knowledge_base["page_embedding"].apply(eval).apply(np.array)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
self.knowledge_base = knowledge_base
|
| 68 |
+
self.upload_state = "done"
|
| 69 |
+
|
| 70 |
+
def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
|
| 71 |
+
db = pd.concat([db, cur_content], ignore_index=True)
|
| 72 |
+
db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
|
| 73 |
+
media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
|
| 74 |
+
request = (
|
| 75 |
+
service.files()
|
| 76 |
+
.update(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media)
|
| 77 |
+
.execute()
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def __init_drive_service(self):
|
| 81 |
+
SCOPES = ["https://www.googleapis.com/auth/drive"]
|
| 82 |
+
SERVICE_ACCOUNT_INFO = os.getenv("CREDENTIALS")
|
| 83 |
+
service_account_info_dict = json.loads(SERVICE_ACCOUNT_INFO)
|
| 84 |
+
|
| 85 |
+
creds = Credentials.from_service_account_info(
|
| 86 |
+
service_account_info_dict, scopes=SCOPES
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return build("drive", "v3", credentials=creds)
|
| 90 |
+
|
| 91 |
+
def __read_db(self, service):
|
| 92 |
+
request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
|
| 93 |
+
fh = io.BytesIO()
|
| 94 |
+
downloader = MediaIoBaseDownload(fh, request)
|
| 95 |
+
|
| 96 |
+
done = False
|
| 97 |
+
while done is False:
|
| 98 |
+
status, done = downloader.next_chunk()
|
| 99 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
| 100 |
+
|
| 101 |
+
fh.seek(0)
|
| 102 |
+
|
| 103 |
+
return pd.read_csv(fh)
|
| 104 |
+
|
| 105 |
+
def __read_file(self, service, filename) -> pd.DataFrame:
|
| 106 |
+
query = f"name='{filename}'"
|
| 107 |
+
results = service.files().list(q=query).execute()
|
| 108 |
+
files = results.get("files", [])
|
| 109 |
+
|
| 110 |
+
file_id = files[0]["id"]
|
| 111 |
+
|
| 112 |
+
request = service.files().get_media(fileId=file_id)
|
| 113 |
+
fh = io.BytesIO()
|
| 114 |
+
downloader = MediaIoBaseDownload(fh, request)
|
| 115 |
+
|
| 116 |
+
done = False
|
| 117 |
+
while done is False:
|
| 118 |
+
status, done = downloader.next_chunk()
|
| 119 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
| 120 |
+
|
| 121 |
+
fh.seek(0)
|
| 122 |
+
|
| 123 |
+
return pd.read_csv(fh)
|
| 124 |
+
|
| 125 |
+
def __upload_file(self, service):
|
| 126 |
+
results = service.files().list(pageSize=10).execute()
|
| 127 |
+
items = results.get("files", [])
|
| 128 |
+
if not items:
|
| 129 |
+
print("No files found.")
|
| 130 |
+
else:
|
| 131 |
+
print("Files:")
|
| 132 |
+
for item in items:
|
| 133 |
+
print(f"{item['name']} ({item['id']})")
|
| 134 |
+
|
| 135 |
+
media = MediaFileUpload(self.csv_result_path, resumable=True)
|
| 136 |
+
filename_prefix = "ex_bot_database_"
|
| 137 |
+
filename = filename_prefix + self.uid + ".csv"
|
| 138 |
+
request = (
|
| 139 |
+
service.files()
|
| 140 |
+
.create(
|
| 141 |
+
media_body=media,
|
| 142 |
+
body={
|
| 143 |
+
"name": filename,
|
| 144 |
+
"parents": [
|
| 145 |
+
"1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"
|
| 146 |
+
],
|
| 147 |
+
},
|
| 148 |
+
)
|
| 149 |
+
.execute()
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def clear_state(self):
|
| 153 |
+
self.context = None
|
| 154 |
+
self.context_page_num = None
|
| 155 |
+
self.context_file_name = None
|
| 156 |
+
self.knowledge_base = None
|
| 157 |
+
self.upload_state = "waiting"
|
| 158 |
+
self.history = []
|
| 159 |
+
|
| 160 |
+
def send_system_notification(self):
|
| 161 |
+
if self.upload_state == "waiting":
|
| 162 |
+
conversation = [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
|
| 163 |
+
return conversation
|
| 164 |
+
elif self.upload_state == "done":
|
| 165 |
+
conversation = [["已上傳文件", "文件處理完成,請開始提問"]]
|
| 166 |
+
return conversation
|
| 167 |
+
|
| 168 |
+
def change_md(self):
|
| 169 |
+
content = self.__construct_summary()
|
| 170 |
+
return gr.Markdown.update(content, visible=True)
|
| 171 |
+
|
| 172 |
+
def __construct_summary(self):
|
| 173 |
+
with open(self.json_result_path, "r", encoding="UTF-8") as fp:
|
| 174 |
+
knowledge_base = json.load(fp)
|
| 175 |
+
|
| 176 |
+
context = ""
|
| 177 |
+
for key in knowledge_base.keys():
|
| 178 |
+
file_name = knowledge_base[key]["file_name"]
|
| 179 |
+
total_page = knowledge_base[key]["total_pages"]
|
| 180 |
+
summary = knowledge_base[key]["summarized_content"]
|
| 181 |
+
file_context = f"""
|
| 182 |
+
### 文件摘要
|
| 183 |
+
{file_name} (共 {total_page} 頁)<br><br>
|
| 184 |
+
{summary}<br><br>
|
| 185 |
+
"""
|
| 186 |
+
context += file_context
|
| 187 |
+
return context
|
| 188 |
+
|
| 189 |
+
def user(self, message):
|
| 190 |
+
self.history += [[message, None]]
|
| 191 |
+
return "", self.history
|
| 192 |
+
|
| 193 |
+
def bot(self):
|
| 194 |
+
user_message = self.history[-1][0]
|
| 195 |
+
print(f"user_message: {user_message}")
|
| 196 |
+
|
| 197 |
+
if self.knowledge_base is None:
|
| 198 |
+
response = [
|
| 199 |
+
[user_message, "請先上傳文件"],
|
| 200 |
+
]
|
| 201 |
+
self.history = response
|
| 202 |
+
return self.history
|
| 203 |
+
|
| 204 |
+
else:
|
| 205 |
+
self.__get_index_file(user_message)
|
| 206 |
+
if self.context is None:
|
| 207 |
+
response = [
|
| 208 |
+
[user_message, "無法找到相關文件,請重新提問"],
|
| 209 |
+
]
|
| 210 |
+
self.history = response
|
| 211 |
+
return self.history
|
| 212 |
+
else:
|
| 213 |
+
qa_processor = QuestionAnswerer()
|
| 214 |
+
bot_message = qa_processor.answer_question(
|
| 215 |
+
self.context,
|
| 216 |
+
self.context_page_num,
|
| 217 |
+
self.context_file_name,
|
| 218 |
+
self.history,
|
| 219 |
+
)
|
| 220 |
+
print(f"bot_message: {bot_message}")
|
| 221 |
+
response = [
|
| 222 |
+
[user_message, bot_message],
|
| 223 |
+
]
|
| 224 |
+
self.history[-1] = response[0]
|
| 225 |
+
return self.history
|
| 226 |
+
|
| 227 |
+
def __get_index_file(self, user_message):
|
| 228 |
+
user_message_embedding = openai.Embedding.create(
|
| 229 |
+
input=user_message, engine="text-embedding-ada-002"
|
| 230 |
+
)["data"][0]["embedding"]
|
| 231 |
+
|
| 232 |
+
self.knowledge_base["distance"] = distances_from_embeddings(
|
| 233 |
+
user_message_embedding,
|
| 234 |
+
self.knowledge_base["page_embedding"].values,
|
| 235 |
+
distance_metric="cosine",
|
| 236 |
+
)
|
| 237 |
+
self.knowledge_base = self.knowledge_base.sort_values(
|
| 238 |
+
by="distance", ascending=True
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if self.knowledge_base["distance"].values[0] > 0.2:
|
| 242 |
+
self.context = None
|
| 243 |
+
else:
|
| 244 |
+
self.context = self.knowledge_base["page_content"].values[0]
|
| 245 |
+
self.context_page_num = self.knowledge_base["page_num"].values[0]
|
| 246 |
+
self.context_file_name = self.knowledge_base["file_name"].values[0]
|
| 247 |
+
|
| 248 |
+
def __generate_uid(self):
|
| 249 |
+
return secrets.token_hex(8)
|
utils/gpt_processor.py
CHANGED
|
@@ -24,38 +24,30 @@ class GPTAgent:
|
|
| 24 |
response = self.agent.complete(messages=messages)
|
| 25 |
return response.choices[0].message["content"]
|
| 26 |
|
| 27 |
-
def split_into_many(
|
| 28 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 29 |
-
# Split the text into sentences
|
| 30 |
-
sentences = text.split("。")
|
| 31 |
|
| 32 |
-
|
| 33 |
n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
|
| 34 |
|
| 35 |
chunks = []
|
| 36 |
tokens_so_far = 0
|
| 37 |
chunk = []
|
| 38 |
|
| 39 |
-
# Loop through the sentences and tokens joined together in a tuple
|
| 40 |
for sentence, token in zip(sentences, n_tokens):
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
# the chunk and tokens so far
|
| 44 |
-
if tokens_so_far + token > self.split_max_tokens:
|
| 45 |
chunks.append("。".join(chunk) + "。")
|
| 46 |
chunk = []
|
| 47 |
tokens_so_far = 0
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if token > self.split_max_tokens:
|
| 52 |
-
continue
|
| 53 |
-
|
| 54 |
-
# Otherwise, add the sentence to the chunk and add the number of tokens to the total
|
| 55 |
chunk.append(sentence)
|
| 56 |
tokens_so_far += token + 1
|
| 57 |
|
| 58 |
-
|
|
|
|
| 59 |
return [text] if len(chunks) == 0 else chunks
|
| 60 |
|
| 61 |
def preprocess(self, text):
|
|
@@ -202,10 +194,59 @@ class Summarizer(GPTAgent):
|
|
| 202 |
system_prompt = """
|
| 203 |
請幫我總結以下的文章。
|
| 204 |
"""
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
try:
|
| 210 |
response = openai.ChatCompletion.create(
|
| 211 |
model=self.model,
|
|
@@ -224,7 +265,7 @@ class Summarizer(GPTAgent):
|
|
| 224 |
response["choices"][0]["message"]["content"]
|
| 225 |
)
|
| 226 |
|
| 227 |
-
return
|
| 228 |
|
| 229 |
|
| 230 |
class QuestionAnswerer(GPTAgent):
|
|
|
|
| 24 |
response = self.agent.complete(messages=messages)
|
| 25 |
return response.choices[0].message["content"]
|
| 26 |
|
| 27 |
+
def split_into_many(text):
|
| 28 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
sentences = text.split("。")
|
| 31 |
n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences]
|
| 32 |
|
| 33 |
chunks = []
|
| 34 |
tokens_so_far = 0
|
| 35 |
chunk = []
|
| 36 |
|
|
|
|
| 37 |
for sentence, token in zip(sentences, n_tokens):
|
| 38 |
+
|
| 39 |
+
if tokens_so_far + token > 500:
|
|
|
|
|
|
|
| 40 |
chunks.append("。".join(chunk) + "。")
|
| 41 |
chunk = []
|
| 42 |
tokens_so_far = 0
|
| 43 |
|
| 44 |
+
if token > 500:
|
| 45 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
chunk.append(sentence)
|
| 47 |
tokens_so_far += token + 1
|
| 48 |
|
| 49 |
+
chunks.append("。".join(chunk) + "。")
|
| 50 |
+
|
| 51 |
return [text] if len(chunks) == 0 else chunks
|
| 52 |
|
| 53 |
def preprocess(self, text):
|
|
|
|
| 194 |
system_prompt = """
|
| 195 |
請幫我總結以下的文章。
|
| 196 |
"""
|
| 197 |
+
|
| 198 |
+
text_chunks = self.split_into_many(text)
|
| 199 |
+
if len(text_chunks) > 1:
|
| 200 |
+
concated_summary = ""
|
| 201 |
+
for i in range(len(text_chunks)):
|
| 202 |
+
text_chunk = text[i].replace("\n", " ").replace("\r", "")
|
| 203 |
+
messages = [
|
| 204 |
+
{"role": "system", "content": f"{system_prompt}"},
|
| 205 |
+
{"role": "user", "content": text_chunk},
|
| 206 |
+
]
|
| 207 |
+
try:
|
| 208 |
+
response = openai.ChatCompletion.create(
|
| 209 |
+
model=self.model,
|
| 210 |
+
messages=messages,
|
| 211 |
+
temperature=self.temperature,
|
| 212 |
+
max_tokens=self.max_tokens,
|
| 213 |
+
frequency_penalty=self.frequency_penalty,
|
| 214 |
+
presence_penalty=self.presence_penalty,
|
| 215 |
+
)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
logging.error(e)
|
| 218 |
+
logging.error("Failed to summarize text_chunk")
|
| 219 |
+
chinese_converter = OpenCC("s2tw")
|
| 220 |
+
concated_summary += chinese_converter.convert(
|
| 221 |
+
response["choices"][0]["message"]["content"].strip()
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# summarize concated_summary
|
| 225 |
+
messages = [
|
| 226 |
+
{"role": "system", "content": f"{system_prompt}"},
|
| 227 |
+
{"role": "user", "content": concated_summary},
|
| 228 |
+
]
|
| 229 |
+
try:
|
| 230 |
+
response = openai.ChatCompletion.create(
|
| 231 |
+
model=self.model,
|
| 232 |
+
messages=messages,
|
| 233 |
+
temperature=self.temperature,
|
| 234 |
+
max_tokens=self.max_tokens,
|
| 235 |
+
frequency_penalty=self.frequency_penalty,
|
| 236 |
+
presence_penalty=self.presence_penalty,
|
| 237 |
+
)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
logging.error(e)
|
| 240 |
+
logging.error("Failed to summarize concated_summary")
|
| 241 |
+
chinese_converter = OpenCC("s2tw")
|
| 242 |
+
return chinese_converter.convert(
|
| 243 |
+
response["choices"][0]["message"]["content"].strip()
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
messages = [
|
| 247 |
+
{"role": "system", "content": f"{system_prompt}"},
|
| 248 |
+
{"role": "user", "content": text},
|
| 249 |
+
]
|
| 250 |
try:
|
| 251 |
response = openai.ChatCompletion.create(
|
| 252 |
model=self.model,
|
|
|
|
| 265 |
response["choices"][0]["message"]["content"]
|
| 266 |
)
|
| 267 |
|
| 268 |
+
return response
|
| 269 |
|
| 270 |
|
| 271 |
class QuestionAnswerer(GPTAgent):
|