alen committed on
Commit
4c385f7
·
verified ·
1 Parent(s): cf9ca3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -35
app.py CHANGED
@@ -1,19 +1,24 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
  from langchain_community.llms import LlamaCpp
3
- from langchain.prompts import ChatPromptTemplate
4
 
5
- vector_db_path = "vectorstores/db_faiss"
6
 
7
- llm_json = LlamaCpp(
8
- model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
9
- temperature=0,
10
- max_tokens=512,
11
- top_p=1,
12
- # callback_manager=callback_manager,
13
- verbose=True,
14
- format='json'
15
- )
16
- print(llm_json)
17
 
18
  llm = LlamaCpp(
19
  model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
@@ -23,31 +28,178 @@ llm = LlamaCpp(
23
  # callback_manager=callback_manager,
24
  verbose=True, # Verbose is required to pass to the callback manager
25
  )
26
- template = """Bạn là trợ lý ảo thông thái tên là Aleni. bạn hãy sử dụng dữ liệu dưới đây để trả lời câu hỏi,
27
- nếu không có thông tin hãy đưa ra câu trả lời sát nhất với câu hỏi từ các thông tin tìm được hoặc tự suy luận
28
- Question: {question}
29
- Chỉ đưa ra các câu trả lời hữu ích.
30
- Helpful answer:
31
- """
32
- # Content: {content}
33
- def respond(message, history, system_message, path_document):
34
- prompt = ChatPromptTemplate.from_template(template)
35
- llm_chain = prompt | llm
36
- respon = ''
37
- for chunk in llm_chain.stream(message):
38
- respon += chunk
39
- # print(chunk.content, end="", flush=True)
40
- yield respon
41
 
42
-
43
- demo = gr.ChatInterface(
44
- respond,
45
- additional_inputs=[
46
- # gr.Textbox(value="Trả lời câu hỏi CHỈ dựa trên ngữ cảnh sau không có thì bảo không có câu trả lời:", label="System message"),
47
- gr.UploadButton("Upload a file", file_count="single"),
48
- # gr.DownloadButton("Download the file")
49
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  if __name__ == "__main__":
53
  demo.launch()
 
1
  import gradio as gr
2
+ from langchain_community.document_loaders import PyPDFLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.document_loaders import PyPDFLoader
5
+ from langchain_community.embeddings import GPT4AllEmbeddings
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain.schema.runnable import RunnablePassthrough
8
+ # from langchain.prompts import ChatPromptTemplate
9
+ # from langchain_community.chat_models import ChatOllama
10
+ from prompt_template import *
11
+ from langgraph.graph import END, StateGraph
12
  from langchain_community.llms import LlamaCpp
 
13
 
14
+ # local_llm = 'aleni_ox'
15
 
16
+ # llm = ChatOllama(model=local_llm,
17
+ # keep_alive="3h",
18
+ # max_tokens=512,
19
+ # temperature=0,
20
+ # # callbacks=[StreamingStdOutCallbackHandler()]
21
+ # )
 
 
 
 
22
 
23
  llm = LlamaCpp(
24
  model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
 
28
  # callback_manager=callback_manager,
29
  verbose=True, # Verbose is required to pass to the callback manager
30
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
# Imported explicitly so the chains below do not depend on the star import
# from prompt_template exporting these names (avoids a potential NameError).
# langchain_core is a hard dependency of langchain_community, so this adds
# no new requirement.
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser

# LCEL chains built from the prompts supplied by prompt_template:
# - question_router: decides (as JSON) whether the question needs web search.
# - generate_chain:  answers the question from gathered context (streamed).
# - query_chain:     rewrites the user question into a web-search query (JSON).
# - llm_chain:       plain Q&A with no retrieval context (default chat mode).
question_router = router_prompt | llm | JsonOutputParser()
generate_chain = generate_prompt | llm | StrOutputParser()
query_chain = query_prompt | llm | JsonOutputParser()
llm_chain = nomalqa_prompt | llm | StrOutputParser()
36
+
37
def generate(state):
    """
    Forward the routed question and gathered context unchanged.

    Despite its name, this graph node does not call the LLM: the final
    answer is streamed by the caller (see QA) via generate_chain, so this
    terminal node only hands back what the graph collected. The original
    docstring claimed a "generation" key was added — it is not.

    Args:
        state (dict): The current graph state; must contain the keys
            "question" and "context".

    Returns:
        dict: {"question": ..., "context": ...} copied from the state.
    """
    print("Step: Đang tạo câu trả lời từ những gì tìm được")
    return {'question': state["question"], 'context': state["context"]}
60
def transform_query(state):
    """
    Rewrite the user's question into an optimised web-search query.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        dict: {"search_query": ...} to be merged into the graph state.
    """
    print("Step: Tối ưu câu hỏi của người dùng")
    # query_chain returns parsed JSON with a "query" field.
    rewritten = query_chain.invoke({"question": state['question']})
    return {"search_query": rewritten["query"]}
76
+
77
def web_search(state):
    """
    Run a web search for the already-optimised search query.

    Args:
        state (dict): The current graph state; must contain "search_query".

    Returns:
        dict: {"context": ...} holding the raw web-search results.
    """
    query = state['search_query']
    print(f'Step: Đang tìm kiếm web cho: "{query}"')

    # Delegate the lookup to the web-search tool (from prompt_template).
    results = web_search_tool.invoke(query)
    print("Search result:", results)
    return {"context": results}
95
+
96
def route_question(state):
    """
    Route the question either to web search or straight to generation.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        str: Name of the next node, "websearch" or "generate".
    """
    print("Step: Routing Query")
    question = state['question']
    output = question_router.invoke({"question": question})
    print('Lựa chọn của AI là: ', output)
    if output == "web_search":
        return "websearch"
    elif output == 'generate':
        return "generate"
    # Bug fix: the original fell off the end (implicitly returning None)
    # whenever the router produced anything other than the two expected
    # labels, which breaks set_conditional_entry_point's mapping. Fall back
    # to plain generation as the safe default.
    return "generate"
117
+
118
# Build the routing graph: the entry point routes the question either
# straight to "generate" or through query optimisation + web search first.
# NOTE(review): `State` is not defined in this file — presumably it comes
# from the star import of prompt_template; confirm it declares the keys
# question / context / search_query used by the nodes above.
workflow = StateGraph(State)
workflow.add_node("websearch", web_search)
workflow.add_node("transform_query", transform_query)
workflow.add_node("generate", generate)

# Build the edges
workflow.set_conditional_entry_point(
    route_question,
    {
        # "websearch" verdicts first optimise the query, then search.
        "websearch": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "websearch")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)

# Compile the workflow
local_agent = workflow.compile()
137
+
138
def run_agent(query):
    """
    Run the compiled agent graph on a single query.

    Args:
        query (str): The user question.

    Returns:
        dict: The final graph state. The original discarded it; returning
        it is backward-compatible (callers ignoring the result see no
        change) and makes this helper actually usable.
    """
    output = local_agent.invoke({"question": query})
    print("=======")
    return output
141
+
142
def QA(question: str, history: list, type: str):
    """
    Chat handler: stream an answer either via the agent graph or directly.

    Args:
        question: The user's message.
        history: Chat history from gr.ChatInterface (only printed for
            debugging in the direct branch).
        type: Mode chosen in the UI radio; any value containing "Agent"
            routes through the graph (web search etc.), anything else uses
            the plain llm_chain.

    Yields:
        str: The accumulated answer, growing one chunk at a time.
    """
    gr.Info("Đang tạo câu trả lời!")
    if 'Agent' in type:
        # Let the graph route the question and gather context first.
        state = local_agent.invoke({"question": question})
        answer = ''
        stream_input = {"context": state['context'], "question": state['question']}
        for piece in generate_chain.stream(stream_input):
            answer += piece
            print(piece, end="", flush=True)
            yield answer
    else:
        print(question, history)
        answer = ''
        for piece in llm_chain.stream(question):
            answer += piece
            print(piece, end="", flush=True)
            yield answer
163
+
164
def create_db(doc: str) -> None:
    """
    Build a FAISS index from an uploaded PDF and switch QA to document mode.

    Args:
        doc: Path to the uploaded PDF file (from gr.File).

    Side effects:
        Rebinds the module-level ``llm_chain`` to a retrieval chain over the
        document, so subsequent non-Agent questions in QA are answered from
        the PDF. Shows a gr.Info toast once the index is built.
        (Return annotation fixed: the original claimed ``-> str`` but
        returns nothing.)
    """
    # Bug fix: the original assigned llm_chain to a LOCAL variable, so the
    # retrieval chain was built and immediately discarded and QA never saw
    # it. Declare it global so the document-aware chain takes effect.
    global llm_chain

    loader = PyPDFLoader(doc)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)

    chunked_documents = loader.load_and_split(text_splitter)
    embedding_model = GPT4AllEmbeddings(
        model_name="all-MiniLM-L6-v2.gguf2.f16.gguf",
        gpt4all_kwargs={'allow_download': 'True'},
    )
    db = FAISS.from_documents(chunked_documents, embedding_model)
    gr.Info("Đã tải lên dữ liệu từ PDF!")

    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )
    llm_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough()}
        | nomaldoc_prompt
        | llm
    )
184
+
185
+
186
# Two-column Gradio UI: left column uploads a PDF (create_db rebuilds the
# retrieval chain as a side effect), right column hosts the chat interface.
with gr.Blocks(fill_height=True) as demo:
    with gr.Row(equal_height=True):

        with gr.Column(scale=1):

            # PDF upload panel; create_db has no visible output component.
            democ2 = gr.Interface(
                create_db,
                [gr.File(file_count='single')],
                None,
            )
        with gr.Column(scale=2):
            # Chat panel; the radio value is passed to QA as `type`.
            # NOTE(review): only 'Agent' changes behaviour inside QA —
            # 'Doc' and 'Coin' currently fall through to the default
            # chain branch; confirm that is intended.
            democ1 = gr.ChatInterface(
                QA,
                additional_inputs=[gr.Radio(['None', 'Agent', 'Doc', 'Coin'], )]


            )

if __name__ == "__main__":
    demo.launch()