File size: 4,849 Bytes
9b2f45b
4820f61
 
 
 
 
 
 
 
4c385f7
4820f61
9b2f45b
4820f61
 
 
eb3b40c
57a65f9
9332631
 
 
284dcd9
57a65f9
 
 
 
4820f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b877de0
4c385f7
 
 
4820f61
4c385f7
 
4820f61
4c385f7
4820f61
 
 
 
4c385f7
4820f61
 
 
4c385f7
 
 
 
 
 
 
 
 
 
 
4820f61
4c385f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4820f61
4c385f7
 
4820f61
4c385f7
 
 
4820f61
 
 
 
 
4c385f7
4820f61
 
 
 
 
 
 
 
 
 
 
4c385f7
4820f61
 
4c385f7
4820f61
 
ed9d94e
 
 
 
 
 
 
 
 
 
 
 
 
4c385f7
4820f61
ed9d94e
4820f61
 
 
 
 
9b2f45b
 
ed9d94e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_community.tools import DuckDuckGoSearchRun
from langchain.prompts import PromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
import pickle
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import END, StateGraph
from huggingface_hub import hf_hub_download

from langchain_community.llms import LlamaCpp
# --- Module-level setup: web search tool, local LLM, prompt templates, chains ---

# DuckDuckGo search, capped at 5 results per query.
wrapper = DuckDuckGoSearchAPIWrapper(max_results=5)
web_search_tool = DuckDuckGoSearchRun(api_wrapper=wrapper)

# Local llama.cpp model; temperature=0 for deterministic routing decisions.
llm = LlamaCpp(
    model_path="Llama-3.1-8B-Instruct.Q5_K_M.gguf",
    temperature=0,
    max_tokens=512,
    n_ctx=2000,
    top_p=1,
    # callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)

# Shared conversation history, appended to by QA() across calls.
chat_history = []

# Load the prompt templates, downloading them from the HF dataset on first run.
# NOTE(review): pickle.load on a downloaded artifact executes arbitrary code if
# the dataset is ever compromised -- consider shipping templates as JSON.
try:
    with open("template.pkl", 'rb') as file:
        template_abox = pickle.load(file)
except (OSError, pickle.UnpicklingError, EOFError):
    # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt
    # and SystemExit. OSError covers the missing/unreadable-file case.
    hf_hub_download(repo_id="linl03/dataAboxChat", local_dir="./", filename="template.pkl", repo_type="dataset")
    with open("./template.pkl", 'rb') as file:
        template_abox = pickle.load(file)

# Routes a question to either web search or direct generation (JSON output).
router_prompt = PromptTemplate(
    template=template_abox["router_template"],
    input_variables=["question"],
)
# Answer-generation prompt with rolling chat history.
generate_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            template_abox["system_prompt"],
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)
# Rewrites the user's question into a search query (JSON output).
query_prompt = PromptTemplate(
    template=template_abox["query_template"],
    input_variables=["question"],
)

# NOTE(review): remind_prompt is built but never used in this file -- confirm
# whether a scheduling feature elsewhere depends on it.
remind_prompt = PromptTemplate(
    template=template_abox["schedule_template"],
    input_variables=["time"],
)

question_router = router_prompt | llm | JsonOutputParser()
generate_chain = generate_prompt | llm | StrOutputParser()
query_chain = query_prompt | llm | JsonOutputParser()
# llm_chain = nomalqa_prompt | llm | StrOutputParser()


class State(TypedDict):
    """Shared state flowing between the LangGraph workflow nodes."""

    question : str       # the user's original question
    generation : str     # generated answer text (not set by the visible nodes -- TODO confirm)
    search_query : str   # optimized query produced by transform_query
    context : str        # web-search results used as context for generation
    
def generate(state):
    """Terminal node: forward the question and gathered context to the caller.

    Args:
        state: workflow state; reads ``question`` and (optionally) ``context``.

    Returns:
        dict with ``question`` and ``context`` keys for the final graph state.
    """
    print("Step: Đang tạo câu trả lời")
    question = state["question"]
    # The router can send a question straight here without passing through
    # web_search, in which case "context" was never set -- the original
    # state["context"] lookup raised KeyError on that path.
    context = state.get("context", "")
    return {'question': question, 'context': context}


def transform_query(state):
    """Rewrite the user's question into an optimized web-search query.

    Invokes the JSON query chain and returns its ``query`` field under the
    ``search_query`` state key.
    """
    print("Step: Tối ưu câu hỏi của người dùng")
    result = query_chain.invoke({"question": state['question']})
    print(result)
    return {"search_query": result["query"]}

def web_search(state):
    """Run the DuckDuckGo tool with the optimized query.

    Returns the raw search output under the ``context`` state key.
    """
    search_query = state['search_query']
    print(f'Step: Đang tìm kiếm web cho: "{search_query}"')

    # Web search tool call
    results = web_search_tool.invoke(search_query)
    print("Search result:", results)
    return {"context": results}

def route_question(state):
    """Choose the workflow entry point for the user's question.

    Asks the router chain to classify the question and maps its ``choice``
    field onto a graph node name.

    Returns:
        "websearch" when the router picks web search, otherwise "generate".
    """
    print("Step: Routing Query")
    question = state['question']
    output = question_router.invoke({"question": question})
    print('Lựa chọn của AI là: ', output)
    choice = output.get('choice') if isinstance(output, dict) else None
    if choice == "web_search":
        return "websearch"
    # Default to direct generation for 'generate' or any unexpected/missing
    # router output -- the original returned None in that case, which crashes
    # the graph's conditional entry point mapping.
    return "generate"

def Agents():
    """Build and compile the LangGraph QA workflow.

    The router decides per question whether to go straight to generation or
    to first optimize the query and search the web.
    """
    graph = StateGraph(State)

    # Nodes
    graph.add_node("websearch", web_search)
    graph.add_node("transform_query", transform_query)
    graph.add_node("generate", generate)

    # Entry point is picked dynamically by the router chain.
    graph.set_conditional_entry_point(
        route_question,
        {"websearch": "transform_query", "generate": "generate"},
    )

    # Linear path: optimize query -> search -> generate -> done.
    for src, dst in (
        ("transform_query", "websearch"),
        ("websearch", "generate"),
        ("generate", END),
    ):
        graph.add_edge(src, dst)

    return graph.compile()

def QA(question: str, history: list):
    """Gradio chat handler: route the question, search if needed, stream the answer.

    Args:
        question: the user's message. With ``multimodal=True`` Gradio delivers
            a dict/object carrying ``text`` and ``files`` rather than a plain
            string -- the original passed that raw object into the chain and
            into ``HumanMessage``; we normalize to text here.
        history: Gradio-managed history (unused; the module-level
            ``chat_history`` is used instead).

    Yields:
        The progressively accumulated response text for streaming display.
    """
    # Normalize multimodal input to plain text. NOTE(review): attached files
    # are currently ignored -- confirm whether they should feed the context.
    if isinstance(question, dict):
        question = question.get("text", "")
    elif hasattr(question, "text"):
        question = question.text

    local_agent = Agents()
    gr.Info("Đang tạo câu trả lời!")
    response = ''
    output = local_agent.invoke({"question": question})
    # .get guards the direct-generation path, where no web search ran.
    context = output.get('context', '')
    questions = output['question']
    for chunk in generate_chain.stream({"context": context, "question": questions, "chat_history": chat_history}):
        response += chunk
        print(chunk, end="", flush=True)
        yield response

    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=response))
        

# Gradio chat UI wired to the agent-backed QA generator above.
demo = gr.ChatInterface(
    QA,
    fill_height=True,
    multimodal=True,  # messages arrive as {text, files}, not plain strings
    title="Box Chat(Agent)",
)

if __name__ == "__main__":
    demo.launch()