aerovfx commited on
Commit
517693e
·
verified ·
1 Parent(s): 6af8013

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -76
app.py CHANGED
@@ -1,88 +1,176 @@
 
1
  import gradio as gr
 
 
 
 
2
  from langchain.chains import ConversationalRetrievalChain
3
  from langchain.memory import ConversationBufferMemory
4
- from langchain_huggingface import HuggingFaceEndpoint
5
-
6
- # ================================
7
- # Cấu hình LLM
8
- # ================================
9
- HF_TOKEN = "hf_xxxxxxx" # thay bằng token riêng
10
- LLM_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
11
- TEMPERATURE = 0.7
12
- MAX_TOKENS = 512
13
-
14
- # ================================
15
- # Hàm khởi tạo QA Chain
16
- # ================================
17
- def initialize_llmchain(vector_db):
18
- if vector_db is None:
19
- raise ValueError("Vector DB chưa được load! Kiểm tra embeddings/path.")
20
-
21
- # LLM chạy CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  llm = HuggingFaceEndpoint(
23
- endpoint_url=f"https://api-inference.huggingface.co/models/{LLM_NAME}",
24
- client_kwargs={"token": HF_TOKEN},
25
- model_kwargs={
26
- "temperature": TEMPERATURE,
27
- "max_new_tokens": MAX_TOKENS,
28
- "device": "cpu"
29
- }
30
  )
31
 
32
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
 
 
 
33
 
 
34
  qa_chain = ConversationalRetrievalChain.from_llm(
35
- llm=llm,
36
- retriever=vector_db.as_retriever(),
 
37
  memory=memory,
38
- return_source_documents=True
 
39
  )
40
-
41
  return qa_chain
42
 
43
- # ================================
44
- # Hàm xử lý conversation
45
- # ================================
46
- def conversation(message, chat_history, qa_chain):
47
- """
48
- message: câu hỏi mới từ user
49
- chat_history: list [(user, bot), ...]
50
- qa_chain: chain đã khởi tạo
51
- """
52
- if not chat_history:
53
- chat_history = []
54
-
55
- formatted_chat_history = [{"role": "user", "content": h[0]} if len(h) > 0 else {"role": "user", "content": ""} for h in chat_history]
56
- formatted_chat_history += [{"role": "assistant", "content": h[1]} if len(h) > 1 else {"role": "assistant", "content": ""} for h in chat_history]
57
-
58
- # Lấy response từ chain
59
- result = qa_chain.invoke({
60
- "question": message,
61
- "chat_history": formatted_chat_history
62
- })
63
-
64
- answer = result["answer"]
65
- chat_history.append((message, answer))
66
- return chat_history, chat_history
67
-
68
- # ================================
69
- # Khởi tạo vector DB và QA Chain
70
- # ================================
71
- # TODO: thay bằng load vector DB thực tế của bạn
72
- vector_db = None # ví dụ: FAISS.load_local("my_faiss_index")
73
- qa_chain = initialize_llmchain(vector_db)
74
-
75
- # ================================
76
- # Setup Gradio UI
77
- # ================================
78
- with gr.Blocks() as demo:
79
- chatbot = gr.Chatbot(height=480, type="messages")
80
- txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi...")
81
- state = gr.State([])
82
-
83
- def respond(message, chat_history):
84
- return conversation(message, chat_history, qa_chain)
85
-
86
- txt.submit(respond, [txt, state], [chatbot, state])
87
-
88
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_community.document_loaders import PyPDFLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain_community.llms import HuggingFaceEndpoint
11
+
12
+ # ------------------------------
13
+ # Configuration & LLM Selection
14
+ # ------------------------------
15
+ list_llm = [
16
+ "meta-llama/Meta-Llama-3-8B-Instruct",
17
+ "mistralai/Mistral-7B-Instruct-v0.2"
18
+ ]
19
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
20
+
21
+ # Token đọc từ Space secret
22
+ api_token = os.getenv("hf_token") # Space secret, không hardcode
23
+
24
+ # ------------------------------
25
+ # PDF Loading & Splitting
26
+ # ------------------------------
27
+ def load_doc(list_file_path):
28
+ pages = []
29
+ for file_path in list_file_path:
30
+ try:
31
+ loader = PyPDFLoader(file_path)
32
+ pages.extend(loader.load())
33
+ except Exception as e:
34
+ print(f"Error loading {file_path}: {e}")
35
+ text_splitter = RecursiveCharacterTextSplitter(
36
+ chunk_size=1024,
37
+ chunk_overlap=32
38
+ )
39
+ return text_splitter.split_documents(pages)
40
+
41
+ # ------------------------------
42
+ # Vector Database Creation
43
+ # ------------------------------
44
+ def create_db(doc_splits):
45
+ embeddings = HuggingFaceEmbeddings() # CPU-only
46
+ vectordb = FAISS.from_documents(doc_splits, embeddings)
47
+ return vectordb
48
+
49
+ # ------------------------------
50
+ # Initialize LLM + QA Chain
51
+ # ------------------------------
52
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
53
  llm = HuggingFaceEndpoint(
54
+ repo_id=llm_model,
55
+ huggingfacehub_api_token=api_token,
56
+ temperature=temperature,
57
+ max_new_tokens=max_tokens,
58
+ top_k=top_k,
 
 
59
  )
60
 
61
+ memory = ConversationBufferMemory(
62
+ memory_key="chat_history",
63
+ output_key='answer',
64
+ return_messages=True
65
+ )
66
 
67
+ retriever = vector_db.as_retriever()
68
  qa_chain = ConversationalRetrievalChain.from_llm(
69
+ llm,
70
+ retriever=retriever,
71
+ chain_type="stuff",
72
  memory=memory,
73
+ return_source_documents=True,
74
+ verbose=False,
75
  )
 
76
  return qa_chain
77
 
78
+ # ------------------------------
79
+ # Database Initialization
80
+ # ------------------------------
81
+ def initialize_database(list_file_obj):
82
+ list_file_path = [x.name for x in list_file_obj if x is not None]
83
+ doc_splits = load_doc(list_file_path)
84
+ vector_db = create_db(doc_splits)
85
+ return vector_db, "Database created!"
86
+
87
+ # ------------------------------
88
+ # LLM Initialization
89
+ # ------------------------------
90
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
91
+ llm_name = list_llm[llm_option]
92
+ qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
93
+ return qa_chain, "QA chain initialized. Chatbot is ready!"
94
+
95
+ # ------------------------------
96
+ # Conversation Utilities
97
+ # ------------------------------
98
+ def format_chat_history(chat_history, max_messages=5):
99
+ formatted = []
100
+ for user_msg, bot_msg in chat_history[-max_messages:]:
101
+ formatted.append(f"User: {user_msg}")
102
+ formatted.append(f"Assistant: {bot_msg}")
103
+ return formatted
104
+
105
+ def conversation(qa_chain, message, history):
106
+ formatted_history = format_chat_history(history)
107
+ try:
108
+ response = qa_chain.invoke({"question": message, "chat_history": formatted_history})
109
+ answer = response["answer"]
110
+ if "Helpful Answer:" in answer:
111
+ answer = answer.split("Helpful Answer:")[-1]
112
+
113
+ sources = response["source_documents"]
114
+ top_sources = [(s.page_content.strip(), s.metadata.get("page", 0) + 1) for s in sources[:3]]
115
+ while len(top_sources) < 3:
116
+ top_sources.append(("", 0))
117
+
118
+ new_history = history + [(message, answer)]
119
+ return qa_chain, gr.update(value=""), new_history, *sum(top_sources, ())
120
+ except Exception as e:
121
+ print(f"Conversation error: {e}")
122
+ return qa_chain, gr.update(value=""), history, "", 0, "", 0, "", 0
123
+
124
+ # ------------------------------
125
+ # Gradio UI
126
+ # ------------------------------
127
+ def demo():
128
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink")) as demo:
129
+ vector_db = gr.State()
130
+ qa_chain = gr.State()
131
+
132
+ gr.HTML("<center><h1>AERO RAG (CPU-only, Safe Secret)</h1></center>")
133
+ gr.Markdown("<b>Query your PDF documents!</b> CPU-only mode. Token must be stored in Hugging Face Space secret `hf_token`.")
134
+
135
+ with gr.Row():
136
+ # Left Column
137
+ with gr.Column(scale=1):
138
+ document = gr.Files(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
139
+ db_btn = gr.Button("Create vector DB")
140
+ db_progress = gr.Textbox(value="Not initialized", show_label=False)
141
+
142
+ llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
143
+ slider_temperature = gr.Slider(0.01, 1.0, 0.5, 0.1, label="Temperature")
144
+ slider_maxtokens = gr.Slider(128, 4096, 1024, 128, label="Max New Tokens")
145
+ slider_topk = gr.Slider(1, 10, 3, 1, label="Top-K Tokens")
146
+ qachain_btn = gr.Button("Initialize QA Chatbot")
147
+ llm_progress = gr.Textbox(value="Not initialized", show_label=False)
148
+
149
+ # Right Column
150
+ with gr.Column(scale=8):
151
+ chatbot = gr.Chatbot(height=480)
152
+ doc_source1 = gr.Textbox(label="Reference 1", lines=2)
153
+ source1_page = gr.Number(label="Page")
154
+ doc_source2 = gr.Textbox(label="Reference 2", lines=2)
155
+ source2_page = gr.Number(label="Page")
156
+ doc_source3 = gr.Textbox(label="Reference 3", lines=2)
157
+ source3_page = gr.Number(label="Page")
158
+ msg = gr.Textbox(placeholder="Ask a question")
159
+ submit_btn = gr.Button("Submit")
160
+ clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
161
+
162
+ # Event Bindings
163
+ db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
164
+ qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
165
+ outputs=[qa_chain, llm_progress])
166
+ msg.submit(conversation, inputs=[qa_chain, msg, chatbot],
167
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
168
+ submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot],
169
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
170
+ clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
171
+ inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
172
+
173
+ demo.queue().launch(debug=True)
174
+
175
+ if __name__ == "__main__":
176
+ demo()