holy committed on
Commit
bdb1da0
·
1 Parent(s): ad3bc1e

app.py add

Browse files
Files changed (2) hide show
  1. app.py +244 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import time
4
+ import pdfplumber
5
+ from dotenv import load_dotenv
6
+ import torch
7
+ from transformers import (
8
+ BertJapaneseTokenizer,
9
+ BertModel,
10
+ AutoTokenizer,
11
+ AutoModelForCausalLM,
12
+ pipeline,
13
+ BitsAndBytesConfig
14
+ )
15
+ from langchain.vectorstores import FAISS
16
+ from langchain.chains import ConversationalRetrievalChain
17
+ from langchain.memory import ConversationBufferMemory
18
+ from langchain.llms import HuggingFacePipeline
19
+ from langchain.embeddings import HuggingFaceEmbeddings
20
+ from langchain_huggingface import HuggingFaceEndpoint
21
+
22
# Load environment variables (e.g. HF_TOKEN) from a local .env file.
load_dotenv()

# Candidate LLM repos: the meta-llama and mistralai entries are used via the
# HF Inference API endpoint; the rinna entries are loaded locally
# (see initialize_llmchain for the dispatch).
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "rinna/llama-3-youko-8b",
    "rinna/japanese-gpt-neox-3.6b"
]
# Short display names (repo basenames) shown in the Gradio radio selector;
# the radio returns an index into list_llm (type="index").
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
31
+
32
# Text extraction for Japanese PDFs
def extract_text_from_pdf(file_path):
    """Extract the text of a PDF as a single space-joined string.

    Args:
        file_path: Path to the PDF file on disk.

    Returns:
        The extracted text of every page joined with single spaces.

    Note:
        ``pdfplumber``'s ``Page.extract_text`` returns ``None`` for pages
        with no extractable text (e.g. scanned/image-only pages); those are
        skipped here — previously they made ``str.join`` raise ``TypeError``.
    """
    with pdfplumber.open(file_path) as pdf:
        pages = [page.extract_text() for page in pdf.pages]
    return " ".join(page for page in pages if page)
37
+
38
# Initialize the Japanese BERT tokenizer and model at import time.
# NOTE(review): tokenizer_bert and model_bert are not referenced anywhere
# else in this file — document embeddings are built via HuggingFaceEmbeddings
# in create_db() — so these two heavy downloads/loads look unused; confirm
# before removing.
tokenizer_bert = BertJapaneseTokenizer.from_pretrained(
    'cl-tohoku/bert-base-japanese',
    clean_up_tokenization_spaces=True
)
model_bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese')
44
+
45
def split_text_simple(text, chunk_size=1024):
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    The final chunk may be shorter; an empty string yields an empty list.
    """
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size
    return chunks
47
+
48
def create_db(splits):
    """Build a FAISS vector store over the given text chunks.

    Each chunk is embedded with a Japanese sentence-BERT model.
    """
    ja_embedder = HuggingFaceEmbeddings(
        model_name='sonoisa/sentence-bert-base-ja-mean-tokens'
    )
    return FAISS.from_texts(splits, ja_embedder)
54
+
55
def initialize_llmchain(
    llm_model,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    retries=5,
    delay=5
):
    """Build a ConversationalRetrievalChain around the selected LLM.

    rinna/* models are loaded locally (4-bit quantized on GPU, fp32 on CPU);
    meta-llama/* and mistralai/* are called through the HF Inference API.
    Authentication hiccups against huggingface_hub are retried up to
    *retries* times, sleeping *delay* seconds between attempts.

    Args:
        llm_model: HF repo id of the model to use.
        temperature: Sampling temperature for generation.
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling parameter (endpoint models only).
        vector_db: Vector store to retrieve context documents from.
        retries: Maximum attempts on authentication failures.
        delay: Seconds to sleep between retries.

    Returns:
        A ready-to-call ConversationalRetrievalChain.

    Raises:
        Exception: For unsupported models, non-auth initialization errors,
            or when all retries are exhausted.
    """
    attempt = 0
    while attempt < retries:
        try:
            # Local model case
            if "rinna" in llm_model.lower():
                if torch.cuda.is_available():
                    # On GPU, load 4-bit quantized (NF4) to fit in VRAM;
                    # compute dtype fp16 is set via the quantization config.
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    model = AutoModelForCausalLM.from_pretrained(
                        llm_model,
                        device_map="auto",
                        quantization_config=quantization_config
                    )
                else:
                    # On CPU, bitsandbytes quantization is unavailable:
                    # load the model in plain fp32 instead.
                    model = AutoModelForCausalLM.from_pretrained(
                        llm_model,
                        device_map={"": "cpu"},
                        torch_dtype=torch.float32
                    )
                tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False)
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=max_tokens,
                    # do_sample=True is required for `temperature` to take
                    # effect; otherwise transformers decodes greedily and
                    # silently ignores the temperature setting.
                    do_sample=True,
                    temperature=temperature
                )
                llm = HuggingFacePipeline(pipeline=pipe)
            # Endpoint model case
            elif "meta-llama" in llm_model.lower() or "mistralai" in llm_model.lower():
                # Pass generation parameters directly to the endpoint.
                llm = HuggingFaceEndpoint(
                    endpoint_url=f"https://api-inference.huggingface.co/models/{llm_model}",
                    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
                    temperature=temperature,
                    max_new_tokens=max_tokens,
                    top_k=top_k
                )
            else:
                # Other models (extend here as needed)
                raise Exception(f"Unsupported model: {llm_model}")

            # Shared wiring: conversation memory + retriever over the store.
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                output_key='answer',
                return_messages=True
            )
            retriever = vector_db.as_retriever()
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm,
                retriever=retriever,
                memory=memory,
                return_source_documents=True,
                verbose=False
            )
            return qa_chain
        except Exception as e:
            # Transient HF auth failures are retried; anything else is fatal.
            if "Could not authenticate with huggingface_hub" in str(e):
                time.sleep(delay)
                attempt += 1
            else:
                # Chain the cause so the original traceback is preserved.
                raise Exception(f"Error initializing QA chain: {str(e)}") from e
    raise Exception(f"Failed to initialize after {retries} attempts")
139
+
140
def process_pdf(file):
    """Gradio callback: extract text from an uploaded PDF and index it.

    Returns a ``(vector_db, status_message)`` pair; ``vector_db`` is ``None``
    on failure so downstream steps can detect the missing index.
    """
    if file is None:
        return None, "Please upload a PDF file."
    try:
        raw_text = extract_text_from_pdf(file.name)
        chunks = split_text_simple(raw_text)
        return create_db(chunks), "PDF processed and vector database created."
    except Exception as err:
        return None, f"Error processing PDF: {str(err)}"
150
+
151
def initialize_qa_chain(
    llm_index,
    temperature,
    max_tokens,
    top_k,
    vector_db
):
    """Gradio callback: build the QA chain for the selected model.

    Args:
        llm_index: Index into list_llm from the radio selector
            (``type="index"``); ``None`` until the user picks a model.
        temperature: Sampling temperature (slider).
        max_tokens: Max new tokens to generate (slider).
        top_k: Top-k sampling parameter (slider).
        vector_db: Vector store produced by process_pdf, or ``None``.

    Returns:
        ``(chain, status_message)``; ``chain`` is ``None`` on any failure.
    """
    try:
        if vector_db is None:
            return None, "Please process a PDF first."
        # The Radio component yields None until a selection is made; fail
        # with a clear message instead of a cryptic list-index TypeError.
        if llm_index is None:
            return None, "Please select an LLM model first."
        llm_name = list_llm[llm_index]
        chain = initialize_llmchain(
            llm_name,
            temperature,
            max_tokens,
            top_k,
            vector_db
        )
        return chain, "QA Chatbot initialized with selected LLM."
    except Exception as e:
        return None, f"Error initializing QA chain: {str(e)}"
172
+
173
def update_chat(msg, history, chain):
    """Gradio callback: append one (user, assistant) exchange to *history*.

    When the chain is not initialized, the assistant replies with a hint;
    any chain error is surfaced as an assistant reply instead of raising.
    """
    user_turn = ("User", msg)
    if chain is None:
        return history + [user_turn, ("Assistant", "Please initialize the QA Chatbot first.")]
    try:
        result = chain({"question": msg, "chat_history": history})
        reply = result['answer']
    except Exception as err:
        reply = f"Error: {str(err)}"
    return history + [user_turn, ("Assistant", reply)]
181
+
182
def demo():
    """Assemble and return the three-tab Gradio app.

    Tab 1 uploads and indexes a PDF, tab 2 selects and initializes an LLM,
    tab 3 chats against the indexed document.
    """
    with gr.Blocks() as demo:
        # Per-session state: the FAISS store and the initialized QA chain.
        vector_db = gr.State(value=None)
        qa_chain = gr.State(value=None)

        with gr.Tab("Step 1 - Upload and Process"):
            with gr.Row():
                document = gr.File(label="Upload your Japanese PDF document", file_types=["pdf"])
            with gr.Row():
                process_btn = gr.Button("Process PDF")
                process_output = gr.Textbox(label="Processing Output")

        with gr.Tab("Step 2 - Initialize QA Chatbot"):
            with gr.Row():
                # type="index" makes the radio emit an index into list_llm.
                llm_btn = gr.Radio(list_llm_simple, label="Select LLM Model", type="index")
                llm_temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                max_tokens = gr.Slider(minimum=128, maximum=2048, step=128, label="Max Tokens", value=1024)
                top_k = gr.Slider(minimum=1, maximum=10, step=1, label="Top K", value=3)
            with gr.Row():
                init_qa_btn = gr.Button("Initialize QA Chatbot")
                init_output = gr.Textbox(label="Initialization Output")

        with gr.Tab("Step 3 - Chat with your Document"):
            chatbot = gr.Chatbot()
            message = gr.Textbox(label="Ask a question")
            with gr.Row():
                send_btn = gr.Button("Send")
                clear_chat_btn = gr.Button("Clear Chat")
                reset_all_btn = gr.Button("Reset All")

        # Wire the buttons to their callbacks defined above in this module.
        process_btn.click(
            process_pdf,
            inputs=[document],
            outputs=[vector_db, process_output]
        )

        init_qa_btn.click(
            initialize_qa_chain,
            inputs=[llm_btn, llm_temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, init_output]
        )

        send_btn.click(
            update_chat,
            inputs=[message, chatbot, qa_chain],
            outputs=[chatbot]
        )

        # Clear Chat button: clears only the chat history.
        clear_chat_btn.click(
            lambda: None,
            outputs=[chatbot]
        )

        # Reset All button: clears the chat history, the PDF vector store,
        # and the chatbot (QA chain) state all at once.
        reset_all_btn.click(
            lambda: (None, None, None),
            outputs=[chatbot, vector_db, qa_chain]
        )
    return demo
242
+
243
# Script entry point: build the Gradio app and serve it on all interfaces
# (0.0.0.0) at port 8188.
if __name__ == "__main__":
    demo().launch(server_name="0.0.0.0", server_port=8188)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.41.2
2
+ pdfplumber==0.9.0
3
+ transformers==4.35.0
4
+ torch==2.0.1
5
+ sentence-transformers==2.2.2
6
+ langchain==0.0.263
7
+ pydantic==1.10.12
8
+ faiss-cpu==1.7.4
9
+ langchain-huggingface==0.0.5
10
+ accelerate==0.34.2
11
+ python-dotenv
12
+ bitsandbytes