Spaces:
Runtime error
Runtime error
Commit ·
c88c1d9
1
Parent(s): 1beaddf
feat/add g-drive connection
Browse files
- app.py +9 -4
- utils/chatbot.py +112 -7
- utils/work_flow_controller.py +12 -10
app.py
CHANGED
|
@@ -30,7 +30,11 @@ with gr.Blocks() as demo:
|
|
| 30 |
|
| 31 |
with gr.Row():
|
| 32 |
index_file = gr.File(
|
| 33 |
-
file_count="multiple", file_types=["pdf"], label="Upload PDF file"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
)
|
| 35 |
|
| 36 |
with gr.Row():
|
|
@@ -42,7 +46,8 @@ with gr.Blocks() as demo:
|
|
| 42 |
3. 可以根據下方的摘要內容來提問
|
| 43 |
4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
|
| 44 |
5. 要切換檢索的文件,請點選「清除」按鈕後再重新提問
|
| 45 |
-
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
with gr.Row():
|
|
@@ -80,6 +85,7 @@ with gr.Blocks() as demo:
|
|
| 80 |
**bot_args
|
| 81 |
).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
|
| 82 |
|
|
|
|
| 83 |
# defining workflow of clear state
|
| 84 |
clear_state_args = dict(
|
| 85 |
fn=clear_state,
|
|
@@ -98,7 +104,7 @@ with gr.Blocks() as demo:
|
|
| 98 |
|
| 99 |
bulid_knowledge_base_args = dict(
|
| 100 |
fn=build_knowledge_base,
|
| 101 |
-
inputs=[user_chatbot, index_file],
|
| 102 |
outputs=None,
|
| 103 |
)
|
| 104 |
|
|
@@ -118,6 +124,5 @@ with gr.Blocks() as demo:
|
|
| 118 |
|
| 119 |
video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
|
| 120 |
|
| 121 |
-
|
| 122 |
if __name__ == "__main__":
|
| 123 |
demo.launch()
|
|
|
|
| 30 |
|
| 31 |
with gr.Row():
|
| 32 |
index_file = gr.File(
|
| 33 |
+
file_count="multiple", file_types=["pdf"], label="Upload PDF file", scale=3
|
| 34 |
+
)
|
| 35 |
+
upload_to_db = gr.CheckboxGroup(
|
| 36 |
+
["Upload to Database"],
|
| 37 |
+
label="是否上傳至資料庫", info="將資料上傳至資料庫時,資料庫會自動建立索引,下次使用時可以直接檢索,預設為僅作這次使用", scale=1
|
| 38 |
)
|
| 39 |
|
| 40 |
with gr.Row():
|
|
|
|
| 46 |
3. 可以根據下方的摘要內容來提問
|
| 47 |
4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
|
| 48 |
5. 要切換檢索的文件,請點選「清除」按鈕後再重新提問
|
| 49 |
+
|
| 50 |
+
""",
|
| 51 |
)
|
| 52 |
|
| 53 |
with gr.Row():
|
|
|
|
| 85 |
**bot_args
|
| 86 |
).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
|
| 87 |
|
| 88 |
+
|
| 89 |
# defining workflow of clear state
|
| 90 |
clear_state_args = dict(
|
| 91 |
fn=clear_state,
|
|
|
|
| 104 |
|
| 105 |
bulid_knowledge_base_args = dict(
|
| 106 |
fn=build_knowledge_base,
|
| 107 |
+
inputs=[user_chatbot, index_file, upload_to_db],
|
| 108 |
outputs=None,
|
| 109 |
)
|
| 110 |
|
|
|
|
| 124 |
|
| 125 |
video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
|
| 126 |
|
|
|
|
| 127 |
if __name__ == "__main__":
|
| 128 |
demo.launch()
|
utils/chatbot.py
CHANGED
|
@@ -1,31 +1,62 @@
|
|
| 1 |
-
import
|
| 2 |
import os
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import openai
|
| 5 |
import pandas as pd
|
| 6 |
-
|
| 7 |
-
|
|
|
|
| 8 |
from openai.embeddings_utils import distances_from_embeddings
|
| 9 |
|
| 10 |
-
from .work_flow_controller import WorkFlowController
|
| 11 |
from .gpt_processor import QuestionAnswerer
|
|
|
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class Chatbot:
|
| 15 |
def __init__(self) -> None:
|
| 16 |
self.history = []
|
| 17 |
self.upload_state = "waiting"
|
|
|
|
| 18 |
|
|
|
|
| 19 |
self.knowledge_base = None
|
| 20 |
self.context = None
|
| 21 |
self.context_page_num = None
|
| 22 |
self.context_file_name = None
|
| 23 |
|
| 24 |
-
def build_knowledge_base(self, files):
|
| 25 |
-
work_flow_controller = WorkFlowController(files)
|
| 26 |
self.csv_result_path = work_flow_controller.csv_result_path
|
| 27 |
self.json_result_path = work_flow_controller.json_result_path
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
|
| 30 |
knowledge_base = pd.read_csv(fp)
|
| 31 |
knowledge_base["page_embedding"] = (
|
|
@@ -35,10 +66,81 @@ class Chatbot:
|
|
| 35 |
self.knowledge_base = knowledge_base
|
| 36 |
self.upload_state = "done"
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def clear_state(self):
|
| 39 |
self.context = None
|
| 40 |
self.context_page_num = None
|
| 41 |
self.context_file_name = None
|
|
|
|
| 42 |
self.upload_state = "waiting"
|
| 43 |
self.history = []
|
| 44 |
|
|
@@ -130,9 +232,12 @@ class Chatbot:
|
|
| 130 |
self.context_page_num = self.knowledge_base["page_num"].values[0]
|
| 131 |
self.context_file_name = self.knowledge_base["file_name"].values[0]
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
class VideoChatbot:
|
| 134 |
def __init__(self) -> None:
|
| 135 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 136 |
self.metadata_keys = ["標題", "逐字稿", "摘要", "關鍵字"]
|
| 137 |
self.metadata = {
|
| 138 |
"c2fK-hxnPSY":{
|
|
|
|
| 1 |
+
import io
|
| 2 |
import os
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import secrets
|
| 6 |
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
import openai
|
| 10 |
import pandas as pd
|
| 11 |
+
from google.oauth2.service_account import Credentials
|
| 12 |
+
from googleapiclient.discovery import build
|
| 13 |
+
from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
|
| 14 |
from openai.embeddings_utils import distances_from_embeddings
|
| 15 |
|
|
|
|
| 16 |
from .gpt_processor import QuestionAnswerer
|
| 17 |
+
from .work_flow_controller import WorkFlowController
|
| 18 |
|
| 19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 20 |
+
openai.api_key = OPENAI_API_KEY
|
| 21 |
|
| 22 |
class Chatbot:
|
| 23 |
def __init__(self) -> None:
|
| 24 |
self.history = []
|
| 25 |
self.upload_state = "waiting"
|
| 26 |
+
self.uid = self.__generate_uid()
|
| 27 |
|
| 28 |
+
self.g_drive_service = self.__init_drive_service()
|
| 29 |
self.knowledge_base = None
|
| 30 |
self.context = None
|
| 31 |
self.context_page_num = None
|
| 32 |
self.context_file_name = None
|
| 33 |
|
| 34 |
+
def build_knowledge_base(self, files, upload_mode="僅作這次使用"):
|
| 35 |
+
work_flow_controller = WorkFlowController(files, self.uid)
|
| 36 |
self.csv_result_path = work_flow_controller.csv_result_path
|
| 37 |
self.json_result_path = work_flow_controller.json_result_path
|
| 38 |
|
| 39 |
+
if upload_mode == "上傳至資料庫":
|
| 40 |
+
self.knowledge_base = self.__get_db_knowledge_base()
|
| 41 |
+
else:
|
| 42 |
+
self.knowledge_base = self.__get_local_knowledge_base()
|
| 43 |
+
|
| 44 |
+
def __get_db_knowledge_base(self):
|
| 45 |
+
filename = "knowledge_base.csv"
|
| 46 |
+
db = self.__read_db(self.g_drive_service)
|
| 47 |
+
cur_content = pd.read_csv(self.csv_result_path)
|
| 48 |
+
for _ in range(10):
|
| 49 |
+
try:
|
| 50 |
+
self.__write_into_db(self.g_drive_service, db, cur_content)
|
| 51 |
+
break
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logging.error(e)
|
| 54 |
+
logging.error("Failed to upload to database, retrying...")
|
| 55 |
+
continue
|
| 56 |
+
self.knowledge_base = db
|
| 57 |
+
self.upload_state = "done"
|
| 58 |
+
|
| 59 |
+
def __get_local_knowledge_base(self):
|
| 60 |
with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
|
| 61 |
knowledge_base = pd.read_csv(fp)
|
| 62 |
knowledge_base["page_embedding"] = (
|
|
|
|
| 66 |
self.knowledge_base = knowledge_base
|
| 67 |
self.upload_state = "done"
|
| 68 |
|
| 69 |
+
def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
|
| 70 |
+
# db = pd.concat([db, cur_content], ignore_index=True)
|
| 71 |
+
# db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
|
| 72 |
+
cur_content.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
|
| 73 |
+
media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
|
| 74 |
+
request = service.files().update(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media).execute()
|
| 75 |
+
|
| 76 |
+
def __init_drive_service(self):
|
| 77 |
+
SCOPES = ['https://www.googleapis.com/auth/drive']
|
| 78 |
+
SERVICE_ACCOUNT_FILE = os.getenv("CREDENTIALS")
|
| 79 |
+
|
| 80 |
+
creds = Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=SCOPES)
|
| 81 |
+
|
| 82 |
+
return build('drive', 'v3', credentials=creds)
|
| 83 |
+
|
| 84 |
+
def __read_db(self, service):
|
| 85 |
+
request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
|
| 86 |
+
fh = io.BytesIO()
|
| 87 |
+
downloader = MediaIoBaseDownload(fh, request)
|
| 88 |
+
|
| 89 |
+
done = False
|
| 90 |
+
while done is False:
|
| 91 |
+
status, done = downloader.next_chunk()
|
| 92 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
| 93 |
+
|
| 94 |
+
# file_content = fh.getvalue().decode('utf-8')
|
| 95 |
+
fh.seek(0)
|
| 96 |
+
|
| 97 |
+
return pd.read_csv(fh)
|
| 98 |
+
|
| 99 |
+
def __read_file(self, service, filename) -> pd.DataFrame:
|
| 100 |
+
query = f"name='{filename}'"
|
| 101 |
+
results = service.files().list(q=query).execute()
|
| 102 |
+
files = results.get('files', [])
|
| 103 |
+
|
| 104 |
+
file_id = files[0]['id']
|
| 105 |
+
|
| 106 |
+
request = service.files().get_media(fileId=file_id)
|
| 107 |
+
fh = io.BytesIO()
|
| 108 |
+
downloader = MediaIoBaseDownload(fh, request)
|
| 109 |
+
|
| 110 |
+
done = False
|
| 111 |
+
while done is False:
|
| 112 |
+
status, done = downloader.next_chunk()
|
| 113 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
| 114 |
+
|
| 115 |
+
# file_content = fh.getvalue().decode('utf-8')
|
| 116 |
+
fh.seek(0)
|
| 117 |
+
|
| 118 |
+
return pd.read_csv(fh)
|
| 119 |
+
|
| 120 |
+
def __upload_file(self, service):
|
| 121 |
+
results = service.files().list(pageSize=10).execute()
|
| 122 |
+
items = results.get('files', [])
|
| 123 |
+
if not items:
|
| 124 |
+
print('No files found.')
|
| 125 |
+
else:
|
| 126 |
+
print('Files:')
|
| 127 |
+
for item in items:
|
| 128 |
+
print(f"{item['name']} ({item['id']})")
|
| 129 |
+
|
| 130 |
+
media = MediaFileUpload(self.csv_result_path, resumable=True)
|
| 131 |
+
filename_prefix = 'ex_bot_database_'
|
| 132 |
+
filename = filename_prefix + self.uid + '.csv'
|
| 133 |
+
request = service.files().create(media_body=media, body={
|
| 134 |
+
'name': filename,
|
| 135 |
+
'parents': ["1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"] # Optional, to place the file in a specific folder
|
| 136 |
+
}).execute()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
def clear_state(self):
|
| 140 |
self.context = None
|
| 141 |
self.context_page_num = None
|
| 142 |
self.context_file_name = None
|
| 143 |
+
self.knowledge_base = None
|
| 144 |
self.upload_state = "waiting"
|
| 145 |
self.history = []
|
| 146 |
|
|
|
|
| 232 |
self.context_page_num = self.knowledge_base["page_num"].values[0]
|
| 233 |
self.context_file_name = self.knowledge_base["file_name"].values[0]
|
| 234 |
|
| 235 |
+
def __generate_uid(self):
|
| 236 |
+
return secrets.token_hex(8)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
class VideoChatbot:
|
| 240 |
def __init__(self) -> None:
|
|
|
|
| 241 |
self.metadata_keys = ["標題", "逐字稿", "摘要", "關鍵字"]
|
| 242 |
self.metadata = {
|
| 243 |
"c2fK-hxnPSY":{
|
utils/work_flow_controller.py
CHANGED
|
@@ -20,10 +20,11 @@ processors = {
|
|
| 20 |
|
| 21 |
|
| 22 |
class WorkFlowController:
|
| 23 |
-
def __init__(self, file_src) -> None:
|
| 24 |
# check if the file_path is list
|
| 25 |
# self.file_paths = self.__get_file_name(file_src)
|
| 26 |
self.file_paths = [x.name for x in file_src]
|
|
|
|
| 27 |
|
| 28 |
print(self.file_paths)
|
| 29 |
|
|
@@ -83,6 +84,7 @@ class WorkFlowController:
|
|
| 83 |
|
| 84 |
for i, _ in enumerate(file["file_content"]):
|
| 85 |
# use i+1 to meet the index of file_content
|
|
|
|
| 86 |
file["file_content"][i + 1][
|
| 87 |
"page_content"
|
| 88 |
] = translator.translate_to_chinese(
|
|
@@ -97,33 +99,34 @@ class WorkFlowController:
|
|
| 97 |
# process file content
|
| 98 |
# return processed data
|
| 99 |
if not file["is_chinese"]:
|
|
|
|
| 100 |
file = self.__translate_to_chinese(file)
|
|
|
|
| 101 |
file = self.__get_embedding(file)
|
|
|
|
| 102 |
file = self.__get_summary(file)
|
| 103 |
return file
|
| 104 |
|
| 105 |
def __dump_to_json(self):
|
| 106 |
with open(
|
| 107 |
-
os.path.join(os.getcwd(), "
|
| 108 |
) as f:
|
| 109 |
print(
|
| 110 |
"Dumping to json, the path is: "
|
| 111 |
-
+ os.path.join(os.getcwd(), "
|
| 112 |
)
|
| 113 |
-
self.json_result_path = os.path.join(os.getcwd(), "
|
| 114 |
json.dump(self.files_info, f, indent=4, ensure_ascii=False)
|
| 115 |
|
| 116 |
def __construct_knowledge_base_dataframe(self):
|
| 117 |
rows = []
|
| 118 |
for file_path, content in self.files_info.items():
|
| 119 |
-
file_full_content = content["file_full_content"]
|
| 120 |
for page_num, page_details in content["file_content"].items():
|
| 121 |
row = {
|
| 122 |
"file_name": content["file_name"],
|
| 123 |
"page_num": page_details["page_num"],
|
| 124 |
"page_content": page_details["page_content"],
|
| 125 |
"page_embedding": page_details["page_embedding"],
|
| 126 |
-
"file_full_content": file_full_content,
|
| 127 |
}
|
| 128 |
rows.append(row)
|
| 129 |
|
|
@@ -132,19 +135,18 @@ class WorkFlowController:
|
|
| 132 |
"page_num",
|
| 133 |
"page_content",
|
| 134 |
"page_embedding",
|
| 135 |
-
"file_full_content",
|
| 136 |
]
|
| 137 |
df = pd.DataFrame(rows, columns=columns)
|
| 138 |
return df
|
| 139 |
|
| 140 |
def __dump_to_csv(self):
|
| 141 |
df = self.__construct_knowledge_base_dataframe()
|
| 142 |
-
df.to_csv(os.path.join(os.getcwd(), "
|
| 143 |
print(
|
| 144 |
"Dumping to csv, the path is: "
|
| 145 |
-
+ os.path.join(os.getcwd(), "
|
| 146 |
)
|
| 147 |
-
self.csv_result_path = os.path.join(os.getcwd(), "
|
| 148 |
|
| 149 |
def __get_file_name(self, file_src):
|
| 150 |
file_paths = [x.name for x in file_src]
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class WorkFlowController:
|
| 23 |
+
def __init__(self, file_src, uid) -> None:
|
| 24 |
# check if the file_path is list
|
| 25 |
# self.file_paths = self.__get_file_name(file_src)
|
| 26 |
self.file_paths = [x.name for x in file_src]
|
| 27 |
+
self.uid = uid
|
| 28 |
|
| 29 |
print(self.file_paths)
|
| 30 |
|
|
|
|
| 84 |
|
| 85 |
for i, _ in enumerate(file["file_content"]):
|
| 86 |
# use i+1 to meet the index of file_content
|
| 87 |
+
print("Translating page: " + str(i + 1))
|
| 88 |
file["file_content"][i + 1][
|
| 89 |
"page_content"
|
| 90 |
] = translator.translate_to_chinese(
|
|
|
|
| 99 |
# process file content
|
| 100 |
# return processed data
|
| 101 |
if not file["is_chinese"]:
|
| 102 |
+
print("Translating to chinese...")
|
| 103 |
file = self.__translate_to_chinese(file)
|
| 104 |
+
print("Getting embedding...")
|
| 105 |
file = self.__get_embedding(file)
|
| 106 |
+
print("Getting summary...")
|
| 107 |
file = self.__get_summary(file)
|
| 108 |
return file
|
| 109 |
|
| 110 |
def __dump_to_json(self):
|
| 111 |
with open(
|
| 112 |
+
os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json"), "w", encoding="utf-8"
|
| 113 |
) as f:
|
| 114 |
print(
|
| 115 |
"Dumping to json, the path is: "
|
| 116 |
+
+ os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
|
| 117 |
)
|
| 118 |
+
self.json_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
|
| 119 |
json.dump(self.files_info, f, indent=4, ensure_ascii=False)
|
| 120 |
|
| 121 |
def __construct_knowledge_base_dataframe(self):
|
| 122 |
rows = []
|
| 123 |
for file_path, content in self.files_info.items():
|
|
|
|
| 124 |
for page_num, page_details in content["file_content"].items():
|
| 125 |
row = {
|
| 126 |
"file_name": content["file_name"],
|
| 127 |
"page_num": page_details["page_num"],
|
| 128 |
"page_content": page_details["page_content"],
|
| 129 |
"page_embedding": page_details["page_embedding"],
|
|
|
|
| 130 |
}
|
| 131 |
rows.append(row)
|
| 132 |
|
|
|
|
| 135 |
"page_num",
|
| 136 |
"page_content",
|
| 137 |
"page_embedding",
|
|
|
|
| 138 |
]
|
| 139 |
df = pd.DataFrame(rows, columns=columns)
|
| 140 |
return df
|
| 141 |
|
| 142 |
def __dump_to_csv(self):
|
| 143 |
df = self.__construct_knowledge_base_dataframe()
|
| 144 |
+
df.to_csv(os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv"), index=False)
|
| 145 |
print(
|
| 146 |
"Dumping to csv, the path is: "
|
| 147 |
+
+ os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
|
| 148 |
)
|
| 149 |
+
self.csv_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
|
| 150 |
|
| 151 |
def __get_file_name(self, file_src):
|
| 152 |
file_paths = [x.name for x in file_src]
|