Spaces:
Build error
Build error
Upload 38 files
Browse files- app.py +13 -0
- requirements.txt +23 -0
- src/.DS_Store +0 -0
- src/agents/entrance_eval_agent/flow.py +18 -0
- src/agents/entrance_eval_agent/func.py +4 -0
- src/agents/exercise_gen_agent/flow.py +18 -0
- src/agents/exercise_gen_agent/func.py +4 -0
- src/agents/highlight_explain_agent/__pycache__/flow.cpython-311.pyc +0 -0
- src/agents/highlight_explain_agent/__pycache__/func.cpython-311.pyc +0 -0
- src/agents/highlight_explain_agent/__pycache__/prompt.cpython-311.pyc +0 -0
- src/agents/highlight_explain_agent/flow.py +27 -0
- src/agents/highlight_explain_agent/func.py +25 -0
- src/agents/highlight_explain_agent/prompt.py +39 -0
- src/agents/lesson_rag_agent/flow.py +18 -0
- src/agents/lesson_rag_agent/func.py +4 -0
- src/agents/primary_chatbot/__pycache__/flow.cpython-311.pyc +0 -0
- src/agents/primary_chatbot/__pycache__/func.cpython-311.pyc +0 -0
- src/agents/primary_chatbot/__pycache__/prompt.cpython-311.pyc +0 -0
- src/agents/primary_chatbot/flow.py +157 -0
- src/agents/primary_chatbot/func.py +169 -0
- src/agents/primary_chatbot/prompt.py +177 -0
- src/apis/__pycache__/create_app.cpython-311.pyc +0 -0
- src/apis/create_app.py +23 -0
- src/apis/interfaces/__pycache__/chat_interface.cpython-311.pyc +0 -0
- src/apis/interfaces/chat_interface.py +44 -0
- src/apis/routers/__pycache__/chat_router.cpython-311.pyc +0 -0
- src/apis/routers/chat_router.py +49 -0
- src/config/__pycache__/llm.cpython-311.pyc +0 -0
- src/config/__pycache__/prompt.cpython-311.pyc +0 -0
- src/config/__pycache__/vector_store.cpython-311.pyc +0 -0
- src/config/constant.py +0 -0
- src/config/llm.py +14 -0
- src/config/prompt.py +206 -0
- src/config/vector_store.py +37 -0
- src/utils/__pycache__/helper.cpython-311.pyc +0 -0
- src/utils/__pycache__/logger.cpython-311.pyc +0 -0
- src/utils/helper.py +27 -0
- src/utils/logger.py +65 -0
app.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
|
| 3 |
+
load_dotenv(override=True)
|
| 4 |
+
|
| 5 |
+
from src.apis.create_app import create_app, api_router
|
| 6 |
+
import uvicorn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
app = create_app()
|
| 10 |
+
|
| 11 |
+
app.include_router(api_router)
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
uvicorn.run("app:app", host="0.0.0.0", port=3002)
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langgraph
|
| 2 |
+
langchain
|
| 3 |
+
python-dotenv
|
| 4 |
+
motor
|
| 5 |
+
langchain-community
|
| 6 |
+
langchain-mongodb
|
| 7 |
+
pytz
|
| 8 |
+
PyJWT==2.8.0
|
| 9 |
+
python_jose==3.3.0
|
| 10 |
+
pydantic[email]
|
| 11 |
+
jose
|
| 12 |
+
langchain-google-genai
|
| 13 |
+
python-dateutil
|
| 14 |
+
pandas
|
| 15 |
+
openpyxl
|
| 16 |
+
langchain-redis
|
| 17 |
+
redis
|
| 18 |
+
bs4
|
| 19 |
+
duckduckgo-search
|
| 20 |
+
firebase-admin
|
| 21 |
+
python-dotenv
|
| 22 |
+
fastapi
|
| 23 |
+
uvicorn[standard]
|
src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/agents/entrance_eval_agent/flow.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, START, END
|
| 2 |
+
from src.config.llm import llm_2_0
|
| 3 |
+
from .func import State
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PrimaryChatBot:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self.builder = StateGraph(State)
|
| 9 |
+
|
| 10 |
+
@staticmethod
|
| 11 |
+
def routing(state: State):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
def node(self):
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def edge(self):
|
| 18 |
+
pass
|
src/agents/entrance_eval_agent/func.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict
|
| 2 |
+
|
| 3 |
+
class State(TypedDict):
|
| 4 |
+
pass
|
src/agents/exercise_gen_agent/flow.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, START, END
|
| 2 |
+
from src.config.llm import llm_2_0
|
| 3 |
+
from .func import State
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PrimaryChatBot:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self.builder = StateGraph(State)
|
| 9 |
+
|
| 10 |
+
@staticmethod
|
| 11 |
+
def routing(state: State):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
def node(self):
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def edge(self):
|
| 18 |
+
pass
|
src/agents/exercise_gen_agent/func.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict
|
| 2 |
+
|
| 3 |
+
class State(TypedDict):
|
| 4 |
+
pass
|
src/agents/highlight_explain_agent/__pycache__/flow.cpython-311.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
src/agents/highlight_explain_agent/__pycache__/func.cpython-311.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
src/agents/highlight_explain_agent/__pycache__/prompt.cpython-311.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
src/agents/highlight_explain_agent/flow.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, START, END
|
| 2 |
+
from .func import State, highlight_explain
|
| 3 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HighlightExplainAgent:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self.builder = StateGraph(State)
|
| 9 |
+
|
| 10 |
+
@staticmethod
|
| 11 |
+
def routing(state: State):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
def node(self):
|
| 15 |
+
self.builder.add_node("highlight_explain", highlight_explain)
|
| 16 |
+
|
| 17 |
+
def edge(self):
|
| 18 |
+
self.builder.add_edge(START, "highlight_explain")
|
| 19 |
+
self.builder.add_edge("highlight_explain", END)
|
| 20 |
+
|
| 21 |
+
def __call__(self) -> CompiledStateGraph:
|
| 22 |
+
self.node()
|
| 23 |
+
self.edge()
|
| 24 |
+
return self.builder.compile()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
highlight_workflow = HighlightExplainAgent()()
|
src/agents/highlight_explain_agent/func.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, AnyStr
|
| 2 |
+
|
| 3 |
+
from .prompt import highlight_explain_chain
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class State(TypedDict):
|
| 7 |
+
domain: AnyStr
|
| 8 |
+
highlight_terms: AnyStr
|
| 9 |
+
adjacent_paragraphs: AnyStr
|
| 10 |
+
question: AnyStr
|
| 11 |
+
explanation: AnyStr
|
| 12 |
+
language: AnyStr
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def highlight_explain(state: State):
|
| 16 |
+
response = await highlight_explain_chain.ainvoke(
|
| 17 |
+
{
|
| 18 |
+
"domain": state["domain"],
|
| 19 |
+
"highlight_terms": state["highlight_terms"],
|
| 20 |
+
"adjacent_paragraphs": state["adjacent_paragraphs"],
|
| 21 |
+
"question": state["question"],
|
| 22 |
+
"language": state["language"],
|
| 23 |
+
}
|
| 24 |
+
)
|
| 25 |
+
return {"explanation": response["explanation"]}
|
src/agents/highlight_explain_agent/prompt.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
+
from typing import Literal, Annotated, AnyStr, TypedDict
|
| 4 |
+
from src.config.llm import llm_2_0 as llm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class HighlightExplain(TypedDict):
|
| 8 |
+
"""Explain the highlight terms in a concise and easy to understand manner."""
|
| 9 |
+
|
| 10 |
+
explanation: Annotated[AnyStr, "The explanation of the highlight terms."]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
highlight_explain_prompt = ChatPromptTemplate(
|
| 14 |
+
[
|
| 15 |
+
(
|
| 16 |
+
"system",
|
| 17 |
+
"""You are a expert in explaining the highlight terms in {domain} domain.
|
| 18 |
+
You are given the higlight terms, adjacent paragraphs of the highlight terms.
|
| 19 |
+
Your task is to explain the highlight terms in a concise and easy to understand manner.
|
| 20 |
+
You are also given the user question.
|
| 21 |
+
|
| 22 |
+
Explanation must be primary in {language} language. But you can use {domain} domain terms in explanation.
|
| 23 |
+
""",
|
| 24 |
+
),
|
| 25 |
+
(
|
| 26 |
+
"human",
|
| 27 |
+
"""
|
| 28 |
+
User question: {question}
|
| 29 |
+
Highlight terms: {highlight_terms}
|
| 30 |
+
Adjacent paragraphs: {adjacent_paragraphs}
|
| 31 |
+
""",
|
| 32 |
+
),
|
| 33 |
+
]
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
highlight_explain_chain = highlight_explain_prompt | llm.with_structured_output(
|
| 38 |
+
HighlightExplain
|
| 39 |
+
)
|
src/agents/lesson_rag_agent/flow.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, START, END
|
| 2 |
+
from src.config.llm import llm_2_0
|
| 3 |
+
from .func import State
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PrimaryChatBot:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self.builder = StateGraph(State)
|
| 9 |
+
|
| 10 |
+
@staticmethod
|
| 11 |
+
def routing(state: State):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
def node(self):
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def edge(self):
|
| 18 |
+
pass
|
src/agents/lesson_rag_agent/func.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict
|
| 2 |
+
|
| 3 |
+
class State(TypedDict):
|
| 4 |
+
pass
|
src/agents/primary_chatbot/__pycache__/flow.cpython-311.pyc
ADDED
|
Binary file (8.08 kB). View file
|
|
|
src/agents/primary_chatbot/__pycache__/func.cpython-311.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|
src/agents/primary_chatbot/__pycache__/prompt.cpython-311.pyc
ADDED
|
Binary file (7.86 kB). View file
|
|
|
src/agents/primary_chatbot/flow.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, START, END
|
| 2 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 3 |
+
from .func import (
|
| 4 |
+
StateRAGAccuracy,
|
| 5 |
+
StateRAGSpeed,
|
| 6 |
+
trim_history,
|
| 7 |
+
route,
|
| 8 |
+
transform_query,
|
| 9 |
+
retrieve_document,
|
| 10 |
+
grade_document,
|
| 11 |
+
generate_answer_rag,
|
| 12 |
+
grade_hallucinations,
|
| 13 |
+
gen_answer_normal,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PrimaryChatBotAccuracy:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.builder = StateGraph(StateRAGAccuracy)
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def routing_after_route(state: StateRAGAccuracy):
|
| 23 |
+
if state["route_response"] == "vectorstore":
|
| 24 |
+
return "transform_query"
|
| 25 |
+
else:
|
| 26 |
+
return "generate_answer_normal"
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def routing_after_retrieve_document(state: StateRAGAccuracy):
|
| 30 |
+
return (
|
| 31 |
+
"grade_document"
|
| 32 |
+
if len(state["documents"]) != 0
|
| 33 |
+
else "generate_answer_normal"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def route_after_grade_document(state: StateRAGAccuracy):
|
| 38 |
+
return (
|
| 39 |
+
"generate_answer_rag"
|
| 40 |
+
if len(state["documents"]) != 0
|
| 41 |
+
else "generate_answer_normal"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def routing_check_pass_grade_hallucinations(state: StateRAGAccuracy):
|
| 46 |
+
return END if state["grade_response"] == "yes" else "generate_answer_normal"
|
| 47 |
+
|
| 48 |
+
def node(self):
|
| 49 |
+
self.builder.add_node("trim_history", trim_history)
|
| 50 |
+
self.builder.add_node("route", route)
|
| 51 |
+
self.builder.add_node("transform_query", transform_query)
|
| 52 |
+
self.builder.add_node("retrieve_document", retrieve_document)
|
| 53 |
+
self.builder.add_node("grade_document", grade_document)
|
| 54 |
+
self.builder.add_node("generate_answer_rag", generate_answer_rag)
|
| 55 |
+
self.builder.add_node("grade_hallucinations", grade_hallucinations)
|
| 56 |
+
self.builder.add_node("generate_answer_normal", gen_answer_normal)
|
| 57 |
+
|
| 58 |
+
def edge(self):
|
| 59 |
+
self.builder.add_edge(START, "trim_history")
|
| 60 |
+
self.builder.add_edge("trim_history", "route")
|
| 61 |
+
self.builder.add_conditional_edges(
|
| 62 |
+
"route",
|
| 63 |
+
self.routing_after_route,
|
| 64 |
+
{
|
| 65 |
+
"transform_query": "transform_query",
|
| 66 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 67 |
+
},
|
| 68 |
+
)
|
| 69 |
+
self.builder.add_edge("transform_query", "retrieve_document")
|
| 70 |
+
self.builder.add_conditional_edges(
|
| 71 |
+
"retrieve_document",
|
| 72 |
+
self.routing_after_retrieve_document,
|
| 73 |
+
{
|
| 74 |
+
"grade_document": "grade_document",
|
| 75 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 76 |
+
},
|
| 77 |
+
)
|
| 78 |
+
self.builder.add_conditional_edges(
|
| 79 |
+
"grade_document",
|
| 80 |
+
self.route_after_grade_document,
|
| 81 |
+
{
|
| 82 |
+
"generate_answer_rag": "generate_answer_rag",
|
| 83 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 84 |
+
},
|
| 85 |
+
)
|
| 86 |
+
self.builder.add_edge("generate_answer_rag", "grade_hallucinations")
|
| 87 |
+
self.builder.add_conditional_edges(
|
| 88 |
+
"grade_hallucinations",
|
| 89 |
+
self.routing_check_pass_grade_hallucinations,
|
| 90 |
+
{
|
| 91 |
+
END: END,
|
| 92 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 93 |
+
},
|
| 94 |
+
)
|
| 95 |
+
self.builder.add_edge("generate_answer_normal", END)
|
| 96 |
+
|
| 97 |
+
def __call__(self) -> CompiledStateGraph:
|
| 98 |
+
self.node()
|
| 99 |
+
self.edge()
|
| 100 |
+
return self.builder.compile()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class PrimaryChatBotSpeed:
|
| 104 |
+
def __init__(self):
|
| 105 |
+
self.builder = StateGraph(StateRAGSpeed)
|
| 106 |
+
|
| 107 |
+
@staticmethod
|
| 108 |
+
def routing_after_retrieve_document(state: StateRAGAccuracy):
|
| 109 |
+
return (
|
| 110 |
+
"generate_answer_rag"
|
| 111 |
+
if len(state["documents"]) != 0
|
| 112 |
+
else "generate_answer_normal"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def routing_after_gen_answer_rag(state: StateRAGAccuracy):
|
| 117 |
+
return END if state["document_id_selected"] else "generate_answer_normal"
|
| 118 |
+
|
| 119 |
+
def node(self):
|
| 120 |
+
self.builder.add_node("trim_history", trim_history)
|
| 121 |
+
self.builder.add_node("transform_query", transform_query)
|
| 122 |
+
self.builder.add_node("retrieve_document", retrieve_document)
|
| 123 |
+
self.builder.add_node("generate_answer_rag", generate_answer_rag)
|
| 124 |
+
self.builder.add_node("generate_answer_normal", gen_answer_normal)
|
| 125 |
+
|
| 126 |
+
def edge(self):
|
| 127 |
+
self.builder.add_edge(START, "trim_history")
|
| 128 |
+
self.builder.add_edge("trim_history", "transform_query")
|
| 129 |
+
self.builder.add_edge("transform_query", "retrieve_document")
|
| 130 |
+
self.builder.add_conditional_edges(
|
| 131 |
+
"retrieve_document",
|
| 132 |
+
self.routing_after_retrieve_document,
|
| 133 |
+
{
|
| 134 |
+
"generate_answer_rag": "generate_answer_rag",
|
| 135 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 136 |
+
},
|
| 137 |
+
)
|
| 138 |
+
self.builder.add_conditional_edges(
|
| 139 |
+
"generate_answer_rag",
|
| 140 |
+
self.routing_after_gen_answer_rag,
|
| 141 |
+
{
|
| 142 |
+
END: END,
|
| 143 |
+
"generate_answer_normal": "generate_answer_normal",
|
| 144 |
+
},
|
| 145 |
+
)
|
| 146 |
+
self.builder.add_edge("generate_answer_normal", END)
|
| 147 |
+
|
| 148 |
+
def __call__(self) -> CompiledStateGraph:
|
| 149 |
+
self.node()
|
| 150 |
+
self.edge()
|
| 151 |
+
return self.builder.compile()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
rag_speed = PrimaryChatBotSpeed()()
|
| 155 |
+
rag_accuracy = PrimaryChatBotAccuracy()()
|
| 156 |
+
|
| 157 |
+
#
|
src/agents/primary_chatbot/func.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import TypedDict, Optional, List, Literal
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from src.utils.helper import (
|
| 5 |
+
fake_token_counter,
|
| 6 |
+
convert_list_context_source_to_str,
|
| 7 |
+
convert_message,
|
| 8 |
+
)
|
| 9 |
+
from src.utils.logger import logger
|
| 10 |
+
from langchain_core.messages import trim_messages, AnyMessage
|
| 11 |
+
from src.config.vector_store import vector_store_chatbot, vector_store_tutor
|
| 12 |
+
from .prompt import (
|
| 13 |
+
RouteQuery,
|
| 14 |
+
route_chain,
|
| 15 |
+
transform_query_chain,
|
| 16 |
+
ExtractFilter,
|
| 17 |
+
extract_filter_chain,
|
| 18 |
+
GradeDocuments,
|
| 19 |
+
GenerateAnswer,
|
| 20 |
+
GradeHallucinations,
|
| 21 |
+
gen_normal_answer_chain,
|
| 22 |
+
gen_answer_rag_chain,
|
| 23 |
+
grade_documents_chain,
|
| 24 |
+
gen_answer_rag_chain,
|
| 25 |
+
grade_documents_chain,
|
| 26 |
+
grade_hallucinations_chain,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class StateRAGAccuracy(TypedDict):
|
| 31 |
+
user_query: str | AnyMessage
|
| 32 |
+
route_response: str
|
| 33 |
+
messages_history: list
|
| 34 |
+
documents: list[Document]
|
| 35 |
+
filter: dict
|
| 36 |
+
llm_response: AnyMessage
|
| 37 |
+
grade_response: Literal["yes", "no"]
|
| 38 |
+
language: str
|
| 39 |
+
document_id_selected: Optional[List]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class StateRAGSpeed(TypedDict):
|
| 43 |
+
user_query: str | AnyMessage
|
| 44 |
+
messages_history: list
|
| 45 |
+
documents: list[Document]
|
| 46 |
+
filter: dict
|
| 47 |
+
llm_response: AnyMessage
|
| 48 |
+
language: str
|
| 49 |
+
document_id_selected: Optional[List]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def trim_history(state: StateRAGAccuracy | StateRAGSpeed):
|
| 53 |
+
history = (
|
| 54 |
+
convert_message(state["messages_history"])
|
| 55 |
+
if state.get("messages_history")
|
| 56 |
+
else None
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if not history:
|
| 60 |
+
return {"messages_history": []}
|
| 61 |
+
|
| 62 |
+
chat_message_history = trim_messages(
|
| 63 |
+
history,
|
| 64 |
+
strategy="last",
|
| 65 |
+
token_counter=fake_token_counter,
|
| 66 |
+
max_tokens=int(os.getenv("HISTORY_TOKEN_LIMIT", 2000)),
|
| 67 |
+
start_on="human",
|
| 68 |
+
end_on="ai",
|
| 69 |
+
include_system=False,
|
| 70 |
+
allow_partial=False,
|
| 71 |
+
)
|
| 72 |
+
return {"messages_history": chat_message_history}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
async def route(state: StateRAGAccuracy):
|
| 76 |
+
logger.info(f"routing")
|
| 77 |
+
question = state["user_query"]
|
| 78 |
+
chat_history = state.get("messages_history", None)
|
| 79 |
+
|
| 80 |
+
route_response: RouteQuery = await route_chain.ainvoke(
|
| 81 |
+
{"question": question, "chat_history": chat_history}
|
| 82 |
+
)
|
| 83 |
+
logger.info(f"Route response: {route_response.datasource}")
|
| 84 |
+
return {"route_response": route_response.datasource}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
async def transform_query(state: StateRAGAccuracy | StateRAGSpeed):
|
| 88 |
+
question = state["user_query"]
|
| 89 |
+
chat_history = state.get("messages_history", None)
|
| 90 |
+
transform_response = await transform_query_chain.ainvoke(
|
| 91 |
+
{"question": question, "chat_history": chat_history}
|
| 92 |
+
)
|
| 93 |
+
logger.info(f"Transform response: {transform_response.content}")
|
| 94 |
+
return {"user_query": transform_response.content}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def retrieve_document(state: StateRAGAccuracy):
|
| 98 |
+
question = state["user_query"]
|
| 99 |
+
filter = state.get("filter", {})
|
| 100 |
+
if filter:
|
| 101 |
+
retriever = vector_store_tutor.as_retriever(
|
| 102 |
+
search_type="similarity_score_threshold",
|
| 103 |
+
search_kwargs={"k": 3, "score_threshold": 0.3},
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
retriever = vector_store_chatbot.as_retriever(
|
| 107 |
+
search_type="similarity_score_threshold",
|
| 108 |
+
search_kwargs={"k": 3, "score_threshold": 0.0},
|
| 109 |
+
)
|
| 110 |
+
documents = retriever.invoke(question, filter=filter)
|
| 111 |
+
show_doc = " \n =============\n".join([doc.page_content for doc in documents])
|
| 112 |
+
logger.info(f"Retrieved documents: {show_doc}")
|
| 113 |
+
return {"documents": documents}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def grade_document(state: StateRAGAccuracy):
|
| 117 |
+
question = state["user_query"]
|
| 118 |
+
documents = state["documents"]
|
| 119 |
+
inputs_bach = [
|
| 120 |
+
{"question": question, "document": doc.page_content} for doc in documents
|
| 121 |
+
]
|
| 122 |
+
grade_document_response: list[GradeDocuments] = await grade_documents_chain.abatch(
|
| 123 |
+
inputs_bach
|
| 124 |
+
)
|
| 125 |
+
logger.info(f"Grade response: {grade_document_response}")
|
| 126 |
+
document_index = [
|
| 127 |
+
index
|
| 128 |
+
for index, doc in enumerate(grade_document_response)
|
| 129 |
+
if doc.binary_score == "yes"
|
| 130 |
+
]
|
| 131 |
+
filtered_documents = [documents[i] for i in document_index]
|
| 132 |
+
|
| 133 |
+
return {"documents": filtered_documents}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
async def generate_answer_rag(state: StateRAGAccuracy):
|
| 137 |
+
question = state["user_query"]
|
| 138 |
+
documents = state["documents"]
|
| 139 |
+
language = state["language"]
|
| 140 |
+
context_str = convert_list_context_source_to_str(documents)
|
| 141 |
+
|
| 142 |
+
gen_answer_response: GenerateAnswer = await gen_answer_rag_chain.ainvoke(
|
| 143 |
+
{"question": question, "context": context_str, "language": language}
|
| 144 |
+
)
|
| 145 |
+
logger.info(f"Generate answer response: {gen_answer_response}")
|
| 146 |
+
id_selected = gen_answer_response.selected_document_index
|
| 147 |
+
return {
|
| 148 |
+
"llm_response": gen_answer_response.answer,
|
| 149 |
+
"document_id_selected": id_selected,
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
async def grade_hallucinations(state: StateRAGAccuracy):
|
| 154 |
+
question = state["user_query"]
|
| 155 |
+
llm_response = state["llm_response"]
|
| 156 |
+
grade_response: GradeHallucinations = await grade_hallucinations_chain.ainvoke(
|
| 157 |
+
{"question": question, "generation": llm_response}
|
| 158 |
+
)
|
| 159 |
+
return {"grade_response": grade_response.binary_score}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
async def gen_answer_normal(state: StateRAGAccuracy):
|
| 163 |
+
question = state["user_query"]
|
| 164 |
+
history = state["messages_history"]
|
| 165 |
+
gen_answer_response = await gen_normal_answer_chain.ainvoke(
|
| 166 |
+
{"question": question, "history": history}
|
| 167 |
+
)
|
| 168 |
+
final_response = gen_answer_response.content + "\nNguồn thông tin: Kiến thức của AI"
|
| 169 |
+
return {"llm_response": final_response}
|
src/agents/primary_chatbot/prompt.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from src.config.llm import llm_2_0 as llm
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RouteQuery(BaseModel):
|
| 9 |
+
"""Route a user query to the most relevant datasource."""
|
| 10 |
+
|
| 11 |
+
datasource: Literal["vectorstore", "casual_convo"] = Field(
|
| 12 |
+
...,
|
| 13 |
+
description="Given a user question choose to route it to casual_convo or a vectorstore.",
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ExtractFilter(BaseModel):
|
| 18 |
+
"""Extract job level and job title from user question."""
|
| 19 |
+
|
| 20 |
+
job_level: str = Field(description="The level of the job the user is asking about.")
|
| 21 |
+
job_title: str = Field(description="The title of the job the user is asking about.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GradeDocuments(BaseModel):
|
| 25 |
+
"""Binary score for relevance check on retrieved documents."""
|
| 26 |
+
|
| 27 |
+
binary_score: str = Field(
|
| 28 |
+
description="Documents are relevant to the question, 'yes' or 'no'"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class GenerateAnswer(BaseModel):
|
| 33 |
+
"""Generate an answer based on the provided documents."""
|
| 34 |
+
|
| 35 |
+
answer: str = Field(description="Generated answer based on the provided documents.")
|
| 36 |
+
selected_document_index: Optional[list[int]] = Field(
|
| 37 |
+
description="Index of the selected document. If not have relevant document then leave it None"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class GradeHallucinations(BaseModel):
|
| 42 |
+
"""Binary score for grounding of generation answer in provided facts."""
|
| 43 |
+
|
| 44 |
+
binary_score: Literal["yes", "no"] = Field(
|
| 45 |
+
description="Whether the answer is grounded in the provided facts. 'yes' if the answer is supported by facts, 'no' if the answer contains information not present or contradicting the given facts"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
route_prompt = ChatPromptTemplate(
|
| 50 |
+
[
|
| 51 |
+
(
|
| 52 |
+
"system",
|
| 53 |
+
"""You are an expert at routing the user's question to vectorstore or casual_convo in {topic} platform.
|
| 54 |
+
choose vectorstore if the question is related to {topic} and casual_convo otherwise. \n
|
| 55 |
+
|
| 56 |
+
example:
|
| 57 |
+
user: Hi are you [this is a random question not related to {topic} so route to casual_convo] : casual_convo
|
| 58 |
+
user: Calculate,... [this question is related to education, system information so route to vectorstore] : vectorstore""",
|
| 59 |
+
),
|
| 60 |
+
("placeholder", "{history}"),
|
| 61 |
+
("human", "{question}"),
|
| 62 |
+
]
|
| 63 |
+
).partial(topic="education")
|
| 64 |
+
|
| 65 |
+
re_write_query_prompt = ChatPromptTemplate(
|
| 66 |
+
[
|
| 67 |
+
(
|
| 68 |
+
"system",
|
| 69 |
+
"""You a question re-writer that converts an input question to a better version that is optimized
|
| 70 |
+
for vectorstore retrieval, and very concise. Look at the input and try to reason about the underlying semantic intent/meaning. The input can also be a
|
| 71 |
+
follow up question, look at the chat history to re-write the question to include necessary info from the chat history to a better version that is optimized
|
| 72 |
+
for vectorstore retrieval without any other info needed. [the topic of convo will be generally around {topic} topic. You need to re-write query base on history and include keyword related to this topic""",
|
| 73 |
+
),
|
| 74 |
+
("placeholder", "{history}"),
|
| 75 |
+
(
|
| 76 |
+
"human",
|
| 77 |
+
"{question}",
|
| 78 |
+
),
|
| 79 |
+
]
|
| 80 |
+
).partial(topic="education")
|
| 81 |
+
|
| 82 |
+
extract_filter_prompt = ChatPromptTemplate.from_messages(
|
| 83 |
+
[
|
| 84 |
+
(
|
| 85 |
+
"system",
|
| 86 |
+
"""You are an expert at extracting metadata from the user's question about {topic} topic and using it to filter the retrieved documents.
|
| 87 |
+
""",
|
| 88 |
+
),
|
| 89 |
+
("placeholder", "{history}"),
|
| 90 |
+
("human", "{question}"),
|
| 91 |
+
]
|
| 92 |
+
).partial(topic="education")
|
| 93 |
+
|
| 94 |
+
check_relevant_document_prompt = ChatPromptTemplate(
|
| 95 |
+
[
|
| 96 |
+
(
|
| 97 |
+
"system",
|
| 98 |
+
"""
|
| 99 |
+
You are a grader assessing relevance of a retrieved document to a user question.
|
| 100 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
| 101 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
| 102 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
| 103 |
+
Then, give a score ranges from 0 to 1, with higher values indicating a stronger match and the more corresponding keywords.
|
| 104 |
+
""",
|
| 105 |
+
),
|
| 106 |
+
(
|
| 107 |
+
"human",
|
| 108 |
+
"Retrieved document: \n\n {document} \nvs\n User question: {question}",
|
| 109 |
+
),
|
| 110 |
+
]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
gen_answer_rag_prompt = ChatPromptTemplate(
|
| 114 |
+
[
|
| 115 |
+
(
|
| 116 |
+
"system",
|
| 117 |
+
"""You are chat bot related to {topic}. You are asked to generate an answer based on the provided documents.
|
| 118 |
+
Your are given context related to job description of a job position. If the context not provided, you just say 'không có tài liệu liên quan'
|
| 119 |
+
Answer in {language} language.
|
| 120 |
+
|
| 121 |
+
Context:
|
| 122 |
+
```
|
| 123 |
+
{context}
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
""",
|
| 127 |
+
),
|
| 128 |
+
(
|
| 129 |
+
"human",
|
| 130 |
+
"""
|
| 131 |
+
Question: {question}
|
| 132 |
+
""",
|
| 133 |
+
),
|
| 134 |
+
]
|
| 135 |
+
).partial(topic="education", language="vietnamese")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
grade_answer_prompt = ChatPromptTemplate(
|
| 139 |
+
[
|
| 140 |
+
(
|
| 141 |
+
"system",
|
| 142 |
+
"""You are a grader assessing whether an answer addresses / resolves a question \n
|
| 143 |
+
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.
|
| 144 |
+
If the LLM Generation is saying that it doesnt know or not sure or stating to keep the questions relevant to topic , grade it as 'yes'.""",
|
| 145 |
+
),
|
| 146 |
+
(
|
| 147 |
+
"human",
|
| 148 |
+
"If the LLM Generation is saying that it doesnt know or not sure or stating to keep the questions relevant to topic , grade it as 'yes'. User question: \n\n {question} \n\n LLM generation: {generation}",
|
| 149 |
+
),
|
| 150 |
+
]
|
| 151 |
+
)
|
| 152 |
+
gen_normal_answer_prompt = ChatPromptTemplate(
|
| 153 |
+
[
|
| 154 |
+
(
|
| 155 |
+
"system",
|
| 156 |
+
"""Bạn là chatbot giải đáp câu hỏi của người dùng dựa trên đoạn hội thoại liên quan đến lĩnh vực giáo dục
|
| 157 |
+
""",
|
| 158 |
+
),
|
| 159 |
+
("placeholder", "{history}"),
|
| 160 |
+
("human", "{question}"),
|
| 161 |
+
]
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
route_chain = route_prompt | llm.with_structured_output(RouteQuery)
|
| 166 |
+
transform_query_chain = re_write_query_prompt | llm
|
| 167 |
+
extract_filter_chain = extract_filter_prompt | llm.with_structured_output(ExtractFilter)
|
| 168 |
+
grade_documents_chain = check_relevant_document_prompt | llm.with_structured_output(
|
| 169 |
+
GradeDocuments
|
| 170 |
+
)
|
| 171 |
+
gen_answer_rag_chain = gen_answer_rag_prompt | llm.with_structured_output(
|
| 172 |
+
GenerateAnswer
|
| 173 |
+
)
|
| 174 |
+
gen_normal_answer_chain = gen_normal_answer_prompt | llm
|
| 175 |
+
grade_hallucinations_chain = grade_answer_prompt | llm.with_structured_output(
|
| 176 |
+
GradeHallucinations
|
| 177 |
+
)
|
src/apis/__pycache__/create_app.cpython-311.pyc
ADDED
|
Binary file (980 Bytes). View file
|
|
|
src/apis/create_app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, APIRouter
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from src.apis.routers.chat_router import router as router_chat
|
| 4 |
+
|
| 5 |
+
api_router = APIRouter()
|
| 6 |
+
api_router.include_router(router_chat)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_app():
|
| 10 |
+
app = FastAPI(
|
| 11 |
+
docs_url="/",
|
| 12 |
+
title="AI Serivce",
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
app.add_middleware(
|
| 16 |
+
CORSMiddleware,
|
| 17 |
+
allow_origins=["*"],
|
| 18 |
+
allow_credentials=True,
|
| 19 |
+
allow_methods=["*"],
|
| 20 |
+
allow_headers=["*"],
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
return app
|
src/apis/interfaces/__pycache__/chat_interface.cpython-311.pyc
ADDED
|
Binary file (2.61 kB). View file
|
|
|
src/apis/interfaces/chat_interface.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ChatBody(BaseModel):
|
| 6 |
+
query: str = Field(..., title="User's query messages")
|
| 7 |
+
history: Optional[list] = Field(None, title="Chat history")
|
| 8 |
+
language: Optional[str] = Field("en", title="Language")
|
| 9 |
+
topic: Optional[str] = Field("education", title="Topic")
|
| 10 |
+
|
| 11 |
+
model_config = {
|
| 12 |
+
"json_schema_extra": {
|
| 13 |
+
"example": {
|
| 14 |
+
"query": "Hệ thống có những tính năng gì",
|
| 15 |
+
"history": [
|
| 16 |
+
{"content": "Bạn là ai vậy", "type": "human"},
|
| 17 |
+
{
|
| 18 |
+
"content": "Tôi là AI hỗ trợ cho hệ thống LearnMigo",
|
| 19 |
+
"type": "ai",
|
| 20 |
+
},
|
| 21 |
+
],
|
| 22 |
+
"language": "Vietnamese",
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class HighlightExplainBody(BaseModel):
|
| 29 |
+
domain: str = Field(..., title="Domain")
|
| 30 |
+
question: str = Field(..., title="User's query messages")
|
| 31 |
+
highlight_terms: str = Field(..., title="Highlight terms")
|
| 32 |
+
adjacent_paragraphs: str = Field(..., title="Adjacent paragraphs")
|
| 33 |
+
language: str = Field("Vietnamese", title="Language")
|
| 34 |
+
model_config = {
|
| 35 |
+
"json_schema_extra": {
|
| 36 |
+
"example": {
|
| 37 |
+
"language": "Vietnamese",
|
| 38 |
+
"domain": "Machine Learning",
|
| 39 |
+
"question": "What does overfitting mean and why is it a problem?",
|
| 40 |
+
"highlight_terms": "overfitting",
|
| 41 |
+
"adjacent_paragraphs": "Overfitting happens when a machine learning model performs well on the training data but poorly on unseen data. This is because the model has learned not just the underlying patterns but also the noise in the training dataset. In contrast, a well-generalized model captures patterns that apply to new data as well.",
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
}
|
src/apis/routers/__pycache__/chat_router.cpython-311.pyc
ADDED
|
Binary file (2.76 kB). View file
|
|
|
src/apis/routers/chat_router.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, status, Depends
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from typing import Annotated
|
| 4 |
+
from src.apis.interfaces.chat_interface import ChatBody, HighlightExplainBody
|
| 5 |
+
from src.agents.primary_chatbot.flow import rag_accuracy, rag_speed
|
| 6 |
+
from src.agents.highlight_explain_agent.flow import highlight_workflow
|
| 7 |
+
|
| 8 |
+
router = APIRouter(prefix="/ai", tags=["AI"])
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@router.post("/rag_accuracy")
|
| 12 |
+
async def primary_chat_accuracy(body: ChatBody):
|
| 13 |
+
response = await rag_accuracy.ainvoke(
|
| 14 |
+
{
|
| 15 |
+
"user_query": body.query,
|
| 16 |
+
"messages_history": body.history,
|
| 17 |
+
"language": body.language,
|
| 18 |
+
}
|
| 19 |
+
)
|
| 20 |
+
final_response = response["llm_response"]
|
| 21 |
+
return JSONResponse(status_code=status.HTTP_200_OK, content=final_response)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@router.post("/rag_speed")
|
| 25 |
+
async def primary_chat_speed(body: ChatBody):
|
| 26 |
+
response = await rag_speed.ainvoke(
|
| 27 |
+
{
|
| 28 |
+
"user_query": body.query,
|
| 29 |
+
"messages_history": body.history,
|
| 30 |
+
"language": body.language,
|
| 31 |
+
}
|
| 32 |
+
)
|
| 33 |
+
final_response = response["llm_response"]
|
| 34 |
+
return JSONResponse(status_code=status.HTTP_200_OK, content=final_response)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@router.post("/highlight_explain")
|
| 38 |
+
async def highlight_explain(body: HighlightExplainBody):
|
| 39 |
+
response = await highlight_workflow.ainvoke(
|
| 40 |
+
{
|
| 41 |
+
"domain": body.domain,
|
| 42 |
+
"question": body.question,
|
| 43 |
+
"highlight_terms": body.highlight_terms,
|
| 44 |
+
"adjacent_paragraphs": body.adjacent_paragraphs,
|
| 45 |
+
"language": body.language,
|
| 46 |
+
}
|
| 47 |
+
)
|
| 48 |
+
final_response = response["explanation"]
|
| 49 |
+
return JSONResponse(status_code=status.HTTP_200_OK, content=final_response)
|
src/config/__pycache__/llm.cpython-311.pyc
ADDED
|
Binary file (674 Bytes). View file
|
|
|
src/config/__pycache__/prompt.cpython-311.pyc
ADDED
|
Binary file (8.98 kB). View file
|
|
|
src/config/__pycache__/vector_store.cpython-311.pyc
ADDED
|
Binary file (868 Bytes). View file
|
|
|
src/config/constant.py
ADDED
|
File without changes
|
src/config/llm.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
+
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
| 3 |
+
|
| 4 |
+
llm_2_0 = ChatGoogleGenerativeAI(
|
| 5 |
+
model="gemini-2.0-flash",
|
| 6 |
+
temperature=0.1,
|
| 7 |
+
max_retries=2,
|
| 8 |
+
)
|
| 9 |
+
llm_1_5 = ChatGoogleGenerativeAI(
|
| 10 |
+
model="gemini-1.5-flash",
|
| 11 |
+
temperature=0.1,
|
| 12 |
+
max_retries=2,
|
| 13 |
+
)
|
| 14 |
+
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
|
src/config/prompt.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
+
from typing import Literal
|
| 4 |
+
from src.config.llm import llm_2_0 as llm
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RouteQuery(BaseModel):
|
| 9 |
+
"""Route a user query to the most relevant datasource."""
|
| 10 |
+
|
| 11 |
+
datasource: Literal["vectorstore", "casual_convo"] = Field(
|
| 12 |
+
...,
|
| 13 |
+
description="Given a user question choose to route it to casual_convo or a vectorstore.",
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ExtractFilter(BaseModel):
|
| 18 |
+
"""Extract job level and job title from user question."""
|
| 19 |
+
|
| 20 |
+
job_level: str = Field(description="The level of the job the user is asking about.")
|
| 21 |
+
job_title: str = Field(description="The title of the job the user is asking about.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GradeDocuments(BaseModel):
|
| 25 |
+
"""Binary score for relevance check on retrieved documents."""
|
| 26 |
+
|
| 27 |
+
binary_score: str = Field(
|
| 28 |
+
description="Documents are relevant to the question, 'yes' or 'no'"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class GenerateAnswer(BaseModel):
|
| 33 |
+
"""Generate an answer based on the provided documents."""
|
| 34 |
+
|
| 35 |
+
answer: str = Field(description="Generated answer based on the provided documents.")
|
| 36 |
+
selected_document_index: Optional[list[int]] = Field(
|
| 37 |
+
description="Index of the selected document. If not have relevant document then leave it None"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class GradeHallucinations(BaseModel):
|
| 42 |
+
"""Binary score for grounding of generation answer in provided facts."""
|
| 43 |
+
|
| 44 |
+
binary_score: Literal["yes", "no"] = Field(
|
| 45 |
+
description="Whether the answer is grounded in the provided facts. 'yes' if the answer is supported by facts, 'no' if the answer contains information not present or contradicting the given facts"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class HighlightExplain(BaseModel):
|
| 50 |
+
"""Explain the highlight terms in a concise and easy to understand manner."""
|
| 51 |
+
|
| 52 |
+
explanation: str = Field(description="Explanation of the highlight terms.")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
route_prompt = ChatPromptTemplate(
|
| 56 |
+
[
|
| 57 |
+
(
|
| 58 |
+
"system",
|
| 59 |
+
"""You are an expert at routing the user's question to vectorstore or casual_convo in {topic} platform.
|
| 60 |
+
choose vectorstore if the question is related to {topic} and casual_convo otherwise. \n
|
| 61 |
+
|
| 62 |
+
example:
|
| 63 |
+
user: Hi are you [this is a random question not related to {topic} so route to casual_convo] : casual_convo
|
| 64 |
+
user: Calculate,... [this question is related to education, system information so route to vectorstore] : vectorstore""",
|
| 65 |
+
),
|
| 66 |
+
("placeholder", "{history}"),
|
| 67 |
+
("human", "{question}"),
|
| 68 |
+
]
|
| 69 |
+
).partial(topic="education")
|
| 70 |
+
|
| 71 |
+
re_write_query_prompt = ChatPromptTemplate(
|
| 72 |
+
[
|
| 73 |
+
(
|
| 74 |
+
"system",
|
| 75 |
+
"""You a question re-writer that converts an input question to a better version that is optimized
|
| 76 |
+
for vectorstore retrieval, and very concise. Look at the input and try to reason about the underlying semantic intent/meaning. The input can also be a
|
| 77 |
+
follow up question, look at the chat history to re-write the question to include necessary info from the chat history to a better version that is optimized
|
| 78 |
+
for vectorstore retrieval without any other info needed. [the topic of convo will be generally around {topic} topic. You need to re-write query base on history and include keyword related to this topic""",
|
| 79 |
+
),
|
| 80 |
+
("placeholder", "{history}"),
|
| 81 |
+
(
|
| 82 |
+
"human",
|
| 83 |
+
"{question}",
|
| 84 |
+
),
|
| 85 |
+
]
|
| 86 |
+
).partial(topic="education")
|
| 87 |
+
|
| 88 |
+
extract_filter_prompt = ChatPromptTemplate.from_messages(
|
| 89 |
+
[
|
| 90 |
+
(
|
| 91 |
+
"system",
|
| 92 |
+
"""You are an expert at extracting metadata from the user's question about {topic} topic and using it to filter the retrieved documents.
|
| 93 |
+
""",
|
| 94 |
+
),
|
| 95 |
+
("placeholder", "{history}"),
|
| 96 |
+
("human", "{question}"),
|
| 97 |
+
]
|
| 98 |
+
).partial(topic="education")
|
| 99 |
+
|
| 100 |
+
check_relevant_document_prompt = ChatPromptTemplate(
|
| 101 |
+
[
|
| 102 |
+
(
|
| 103 |
+
"system",
|
| 104 |
+
"""
|
| 105 |
+
You are a grader assessing relevance of a retrieved document to a user question.
|
| 106 |
+
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
|
| 107 |
+
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
| 108 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
| 109 |
+
Then, give a score ranges from 0 to 1, with higher values indicating a stronger match and the more corresponding keywords.
|
| 110 |
+
""",
|
| 111 |
+
),
|
| 112 |
+
(
|
| 113 |
+
"human",
|
| 114 |
+
"Retrieved document: \n\n {document} \nvs\n User question: {question}",
|
| 115 |
+
),
|
| 116 |
+
]
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
gen_answer_rag_prompt = ChatPromptTemplate(
|
| 120 |
+
[
|
| 121 |
+
(
|
| 122 |
+
"system",
|
| 123 |
+
"""You are chat bot related to {topic}. You are asked to generate an answer based on the provided documents.
|
| 124 |
+
Your are given context related to job description of a job position. If the context not provided, you just say 'không có tài liệu liên quan'
|
| 125 |
+
Answer in {language} language.
|
| 126 |
+
|
| 127 |
+
Context:
|
| 128 |
+
```
|
| 129 |
+
{context}
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
""",
|
| 133 |
+
),
|
| 134 |
+
(
|
| 135 |
+
"human",
|
| 136 |
+
"""
|
| 137 |
+
Question: {question}
|
| 138 |
+
""",
|
| 139 |
+
),
|
| 140 |
+
]
|
| 141 |
+
).partial(topic="education", language="vietnamese")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
grade_answer_prompt = ChatPromptTemplate(
|
| 145 |
+
[
|
| 146 |
+
(
|
| 147 |
+
"system",
|
| 148 |
+
"""You are a grader assessing whether an answer addresses / resolves a question \n
|
| 149 |
+
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.
|
| 150 |
+
If the LLM Generation is saying that it doesnt know or not sure or stating to keep the questions relevant to topic , grade it as 'yes'.""",
|
| 151 |
+
),
|
| 152 |
+
(
|
| 153 |
+
"human",
|
| 154 |
+
"If the LLM Generation is saying that it doesnt know or not sure or stating to keep the questions relevant to topic , grade it as 'yes'. User question: \n\n {question} \n\n LLM generation: {generation}",
|
| 155 |
+
),
|
| 156 |
+
]
|
| 157 |
+
)
|
| 158 |
+
gen_normal_answer_prompt = ChatPromptTemplate(
|
| 159 |
+
[
|
| 160 |
+
(
|
| 161 |
+
"system",
|
| 162 |
+
"""Bạn là chatbot giải đáp câu hỏi của người dùng dựa trên đoạn hội thoại liên quan đến lĩnh vực giáo dục
|
| 163 |
+
""",
|
| 164 |
+
),
|
| 165 |
+
("placeholder", "{history}"),
|
| 166 |
+
("human", "{question}"),
|
| 167 |
+
]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
highlight_explain_prompt = ChatPromptTemplate(
|
| 171 |
+
[
|
| 172 |
+
(
|
| 173 |
+
"system",
|
| 174 |
+
"""You are a expert in explaining the highlight terms in {domain} domain.
|
| 175 |
+
You are given the higlight terms, adjacent paragraphs of the highlight terms.
|
| 176 |
+
Your task is to explain the highlight terms in a concise and easy to understand manner.
|
| 177 |
+
You are also given the user question.
|
| 178 |
+
""",
|
| 179 |
+
),
|
| 180 |
+
(
|
| 181 |
+
"human",
|
| 182 |
+
"""
|
| 183 |
+
User question: {question}
|
| 184 |
+
Highlight terms: {highlight_terms}
|
| 185 |
+
Adjacent paragraphs: {adjacent_paragraphs}
|
| 186 |
+
""",
|
| 187 |
+
),
|
| 188 |
+
]
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
route_chain = route_prompt | llm.with_structured_output(RouteQuery)
|
| 192 |
+
transform_query_chain = re_write_query_prompt | llm
|
| 193 |
+
extract_filter_chain = extract_filter_prompt | llm.with_structured_output(ExtractFilter)
|
| 194 |
+
grade_documents_chain = check_relevant_document_prompt | llm.with_structured_output(
|
| 195 |
+
GradeDocuments
|
| 196 |
+
)
|
| 197 |
+
gen_answer_rag_chain = gen_answer_rag_prompt | llm.with_structured_output(
|
| 198 |
+
GenerateAnswer
|
| 199 |
+
)
|
| 200 |
+
gen_normal_answer_chain = gen_normal_answer_prompt | llm
|
| 201 |
+
grade_hallucinations_chain = grade_answer_prompt | llm.with_structured_output(
|
| 202 |
+
GradeHallucinations
|
| 203 |
+
)
|
| 204 |
+
highlight_explain_chain = highlight_explain_prompt | llm.with_structured_output(
|
| 205 |
+
HighlightExplain
|
| 206 |
+
)
|
src/config/vector_store.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_mongodb import MongoDBAtlasVectorSearch
|
| 2 |
+
from pymongo import MongoClient
|
| 3 |
+
from .llm import embeddings
|
| 4 |
+
import os
|
| 5 |
+
from langchain_pinecone import PineconeVectorStore
|
| 6 |
+
|
| 7 |
+
# client = MongoClient(os.getenv("MONGO_CONNECTION_STR"))
|
| 8 |
+
|
| 9 |
+
# DB_NAME = os.getenv("DB_NAME")
|
| 10 |
+
# COLLECTION_NAME = os.getenv("COLLECTION_NAME")
|
| 11 |
+
# ATLAS_VECTOR_CHATBOT_INDEX_NAME = os.getenv("ATLAS_VECTOR_CHATBOT_INDEX_NAME")
|
| 12 |
+
# ATLAS_VECTOR_TUTOR_INDEX_NAME = os.getenv("ATLAS_VECTOR_TUTOR_INDEX_NAME")
|
| 13 |
+
|
| 14 |
+
# MONGODB_COLLECTION_CHATBOT = client[DB_NAME][ATLAS_VECTOR_CHATBOT_INDEX_NAME]
|
| 15 |
+
# MONGODB_COLLECTION_TUTOR = client[DB_NAME][ATLAS_VECTOR_TUTOR_INDEX_NAME]
|
| 16 |
+
|
| 17 |
+
# vector_store_chatbot = MongoDBAtlasVectorSearch(
|
| 18 |
+
# collection=MONGODB_COLLECTION_CHATBOT,
|
| 19 |
+
# embedding=embeddings,
|
| 20 |
+
# index_name=ATLAS_VECTOR_CHATBOT_INDEX_NAME,
|
| 21 |
+
# relevance_score_fn="cosine",
|
| 22 |
+
# )
|
| 23 |
+
# vector_store_tutor = MongoDBAtlasVectorSearch(
|
| 24 |
+
# collection=MONGODB_COLLECTION_TUTOR,
|
| 25 |
+
# embedding=embeddings,
|
| 26 |
+
# index_name=ATLAS_VECTOR_TUTOR_INDEX_NAME,
|
| 27 |
+
# relevance_score_fn="cosine",
|
| 28 |
+
# )
|
| 29 |
+
API_PINCONE_KEY = os.getenv("PINECONE_API_KEY")
|
| 30 |
+
index_tutor = "tutor-vector-store"
|
| 31 |
+
index_chatbot = "chatbot-vector-store"
|
| 32 |
+
vector_store_tutor = PineconeVectorStore(
|
| 33 |
+
index_name=index_tutor, embedding=embeddings, pinecone_api_key=API_PINCONE_KEY
|
| 34 |
+
)
|
| 35 |
+
vector_store_chatbot = PineconeVectorStore(
|
| 36 |
+
index_name=index_chatbot, embedding=embeddings, pinecone_api_key=API_PINCONE_KEY
|
| 37 |
+
)
|
src/utils/__pycache__/helper.cpython-311.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
src/utils/__pycache__/logger.cpython-311.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
src/utils/helper.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.documents import Document
|
| 2 |
+
from typing import Union
|
| 3 |
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
|
| 7 |
+
if isinstance(messages, list):
|
| 8 |
+
return sum(len(message.content.split()) for message in messages)
|
| 9 |
+
return len(messages.content.split())
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def convert_list_context_source_to_str(contexts: list[Document]):
|
| 13 |
+
formatted_str = ""
|
| 14 |
+
for i, context in enumerate(contexts):
|
| 15 |
+
formatted_str += f"Document index {i}:\nContent: {context.page_content}\n"
|
| 16 |
+
formatted_str += "----------------------------------------------\n\n"
|
| 17 |
+
return formatted_str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def convert_message(messages):
|
| 21 |
+
list_message = []
|
| 22 |
+
for message in messages:
|
| 23 |
+
if message["type"] == "human":
|
| 24 |
+
list_message.append(HumanMessage(content=message["content"]))
|
| 25 |
+
else:
|
| 26 |
+
list_message.append(AIMessage(content=message["content"]))
|
| 27 |
+
return list_message
|
src/utils/logger.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytz
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CoreCFG:
|
| 10 |
+
PROJECT_NAME = "SCHEDULE AI"
|
| 11 |
+
BOT_NAME = str("SCHEDULE AI")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_date_time():
|
| 15 |
+
return datetime.now(pytz.timezone("Asia/Ho_Chi_Minh"))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DATE_TIME = get_date_time().date()
|
| 19 |
+
BASE_DIR = os.path.dirname(Path(__file__).parent.parent)
|
| 20 |
+
LOG_DIR = os.path.join(BASE_DIR, "logs")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CustomFormatter(logging.Formatter):
|
| 24 |
+
green = "\x1b[0;32m"
|
| 25 |
+
grey = "\x1b[38;5;248m"
|
| 26 |
+
yellow = "\x1b[38;5;229m"
|
| 27 |
+
red = "\x1b[31;20m"
|
| 28 |
+
bold_red = "\x1b[31;1m"
|
| 29 |
+
blue = "\x1b[38;5;31m"
|
| 30 |
+
white = "\x1b[38;5;255m"
|
| 31 |
+
reset = "\x1b[38;5;15m"
|
| 32 |
+
|
| 33 |
+
base_format = f"{grey}%(asctime)s | %(name)s | %(threadName)s | {{level_color}}%(levelname)-8s{grey} | {blue}%(module)s:%(lineno)d{grey} - {white}%(message)s"
|
| 34 |
+
|
| 35 |
+
FORMATS = {
|
| 36 |
+
logging.INFO: base_format.format(level_color=green),
|
| 37 |
+
logging.WARNING: base_format.format(level_color=yellow),
|
| 38 |
+
logging.ERROR: base_format.format(level_color=red),
|
| 39 |
+
logging.CRITICAL: base_format.format(level_color=bold_red),
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def format(self, record):
|
| 43 |
+
log_fmt = self.FORMATS.get(record.levelno)
|
| 44 |
+
formatter = logging.Formatter(log_fmt)
|
| 45 |
+
return formatter.format(record)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def custom_logger(app_name="APP"):
|
| 49 |
+
logger_r = logging.getLogger(name=app_name)
|
| 50 |
+
# Set the timezone to Ho_Chi_Minh
|
| 51 |
+
tz = pytz.timezone("Asia/Ho_Chi_Minh")
|
| 52 |
+
|
| 53 |
+
logging.Formatter.converter = lambda *args: datetime.now(tz).timetuple()
|
| 54 |
+
|
| 55 |
+
ch = logging.StreamHandler()
|
| 56 |
+
ch.setLevel(logging.INFO)
|
| 57 |
+
ch.setFormatter(CustomFormatter())
|
| 58 |
+
|
| 59 |
+
logger_r.setLevel(logging.INFO)
|
| 60 |
+
logger_r.addHandler(ch)
|
| 61 |
+
|
| 62 |
+
return logger_r
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
logger = custom_logger(app_name=CoreCFG.PROJECT_NAME)
|