ABAO77 committed on
Commit
031378e
·
verified ·
1 Parent(s): 55e58da

Upload 72 files

Browse files
src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc CHANGED
Binary files a/src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc and b/src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc differ
 
src/agents/custom_chatbot/func.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TypedDict,Optional
2
  from langchain_core.messages import AnyMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
4
  from typing import Sequence, Annotated
@@ -26,7 +26,7 @@ def get_info_collection(messages):
26
 
27
 
28
  async def collection_info_agent(state: State):
29
- model_name = state.get("model_name", "gpt-4o")
30
  _, collection_info_agent = get_custom_chatbot_chains(model_name)
31
  return await collection_info_agent.ainvoke(state)
32
 
@@ -35,7 +35,7 @@ async def create_prompt(state: State):
35
  messages = state.get("messages")
36
  name, info = get_info_collection(messages)
37
  logger.info(f"create_prompt {info}")
38
- model_name = state.get("model_name", "gpt-4o")
39
  create_system_chain, _ = get_custom_chatbot_chains(model_name)
40
  res = await create_system_chain.ainvoke({"info": info})
41
  return {"prompt": res.content, "name": name}
 
1
+ from typing import TypedDict, Optional
2
  from langchain_core.messages import AnyMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
4
  from typing import Sequence, Annotated
 
26
 
27
 
28
async def collection_info_agent(state: State):
    """Run the info-collection agent chain against the current graph state.

    Reads the optional ``model_name`` from the state (may be absent/None)
    and invokes the matching chatbot agent chain asynchronously.
    """
    selected_model = state.get("model_name")
    # get_custom_chatbot_chains returns (system-prompt chain, agent chain);
    # only the agent chain is needed here.
    _, info_agent_chain = get_custom_chatbot_chains(selected_model)
    return await info_agent_chain.ainvoke(state)
32
 
 
35
  messages = state.get("messages")
36
  name, info = get_info_collection(messages)
37
  logger.info(f"create_prompt {info}")
38
+ model_name = state.get("model_name")
39
  create_system_chain, _ = get_custom_chatbot_chains(model_name)
40
  res = await create_system_chain.ainvoke({"info": info})
41
  return {"prompt": res.content, "name": name}
src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc CHANGED
Binary files a/src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc differ
 
src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc CHANGED
Binary files a/src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc differ
 
src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc CHANGED
Binary files a/src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc differ
 
src/agents/rag_agent_template/func.py CHANGED
@@ -1,27 +1,30 @@
1
  from typing import TypedDict, Optional, List
2
  from langchain_core.messages import AnyMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
4
- from .prompt import get_rag_chains
5
  from typing import Sequence, Annotated
6
  from langchain_core.messages import RemoveMessage
7
  from langchain_core.documents import Document
8
- from .tools import retrieve_document
9
  from src.utils.logger import logger
 
 
10
 
11
- tools = [retrieve_document]
12
 
13
 
14
  class State(TypedDict):
15
  messages: Annotated[Sequence[AnyMessage], add_messages]
16
  selected_ids: Optional[List[str]]
17
  selected_documents: Optional[List[Document]]
18
- tools: list
19
  prompt: str
20
  model_name: Optional[str]
21
 
22
 
23
  def trim_history(state: State):
24
  history = state.get("messages", [])
 
 
25
  if len(history) > 10:
26
  num_to_remove = len(history) - 10
27
  remove_messages = [
@@ -38,7 +41,10 @@ def trim_history(state: State):
38
 
39
  def execute_tool(state: State):
40
  tool_calls = state["messages"][-1].tool_calls
 
41
  tool_name_to_func = {tool.name: tool for tool in tools}
 
 
42
  selected_ids = []
43
  selected_documents = []
44
  tool_messages = []
@@ -64,7 +70,11 @@ def execute_tool(state: State):
64
  )
65
  continue
66
  tool_response = tool_func.invoke(tool_args)
67
- tool_messages.append(tool_response)
 
 
 
 
68
 
69
  return {
70
  "selected_ids": selected_ids,
@@ -75,22 +85,30 @@ def execute_tool(state: State):
75
 
76
  def generate_answer_rag(state: State):
77
  messages = state["messages"]
78
- tools = state["tools"]
79
- model_name = state.get("model_name", "gemini-2.0-flash") # default to gemini-2.0-flash
80
- rag_answering_chain_tool, rag_answering_chain = get_rag_chains(model_name)
81
- logger.info(f"tools: {tools}")
82
- if tools:
83
- response = rag_answering_chain_tool.invoke(
84
- {
85
- "messages": messages,
86
- "prompt": state["prompt"] + "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời",
87
- }
88
- )
89
- else:
90
- response = rag_answering_chain.invoke(
91
- {
92
- "messages": messages,
93
- "prompt": state["prompt"],
94
- }
95
- )
 
 
 
 
 
 
 
 
96
  return {"messages": response}
 
1
  from typing import TypedDict, Optional, List
2
  from langchain_core.messages import AnyMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
 
4
  from typing import Sequence, Annotated
5
  from langchain_core.messages import RemoveMessage
6
  from langchain_core.documents import Document
7
+ from .tools import retrieve_document, python_repl, duckduckgo_search
8
  from src.utils.logger import logger
9
+ from src.config.llm import get_llm
10
+ from .prompt import template_prompt
11
 
12
+ tools = [retrieve_document, python_repl, duckduckgo_search]
13
 
14
 
15
class State(TypedDict):
    """LangGraph state for the RAG agent template graph."""

    # Conversation history; add_messages merges new messages into the list.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # IDs of documents chosen by the retrieval tool, if any.
    selected_ids: Optional[List[str]]
    # Full Document objects corresponding to selected_ids.
    selected_documents: Optional[List[Document]]
    # Names of tools enabled for this run (matched against the module-level
    # `tools` registry, e.g. "retrieve_document").
    tools: Optional[List[str]]
    # System prompt text injected into the answering chain.
    prompt: str
    # LLM model identifier; generate_answer_rag falls back to a default
    # when this key is absent.
    model_name: Optional[str]
22
 
23
 
24
  def trim_history(state: State):
25
  history = state.get("messages", [])
26
+ tool_names = state.get("tools", [])
27
+
28
  if len(history) > 10:
29
  num_to_remove = len(history) - 10
30
  remove_messages = [
 
41
 
42
  def execute_tool(state: State):
43
  tool_calls = state["messages"][-1].tool_calls
44
+ tool_names = state.get("tools", [])
45
  tool_name_to_func = {tool.name: tool for tool in tools}
46
+ tool_functions = [tool_name_to_func[name] for name in tool_names if name in tool_name_to_func]
47
+
48
  selected_ids = []
49
  selected_documents = []
50
  tool_messages = []
 
70
  )
71
  continue
72
  tool_response = tool_func.invoke(tool_args)
73
+ print(f"tool_response: {tool_response}")
74
+ tool_messages.append(ToolMessage(
75
+ tool_call_id=tool_id,
76
+ content=tool_response,
77
+ ))
78
 
79
  return {
80
  "selected_ids": selected_ids,
 
85
 
86
def generate_answer_rag(state: State):
    """Generate the assistant's answer, binding the tools enabled in state.

    Resolves the tool names in ``state["tools"]`` to tool objects from the
    module-level ``tools`` registry, binds them to the LLM, appends a
    per-tool usage hint to the system prompt, and invokes the
    ``template_prompt | llm`` chain on the conversation history.

    Args:
        state: Graph state; reads ``messages``, ``tools``, ``prompt`` and
            ``model_name``.

    Returns:
        dict: ``{"messages": response}`` with the LLM's response message.
    """
    messages = state["messages"]
    # `tools` is Optional in State: guard against an explicit None value as
    # well as a missing key (state.get("tools", []) alone would crash on None).
    tool_names = state.get("tools") or []
    prompt = state["prompt"]
    model_name = state.get("model_name", "gemini-2.0-flash")

    # Resolve enabled names to tool objects; unknown names are silently
    # skipped so a stale client selection cannot crash the graph.
    tool_name_to_func = {tool.name: tool for tool in tools}
    tool_functions = [
        tool_name_to_func[name] for name in tool_names if name in tool_name_to_func
    ]
    # Use the module logger (lazy %-formatting) instead of print, consistent
    # with the rest of this file.
    logger.info("tools: %s", tool_functions)

    llm_call = template_prompt | get_llm(model_name).bind_tools(tool_functions)

    # Usage hint appended to the system prompt for each enabled tool.
    tool_prompt_hints = {
        "retrieve_document": "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời",
        "python_repl": "Sử dụng tool `python_repl` để thực hiện các tác vụ liên quan đến tính toán phức tạp",
        "duckduckgo_search": "Sử dụng tool `duckduckgo_search` để tìm kiếm thông tin trên internet",
    }
    for tool_obj in tool_functions:
        prompt += tool_prompt_hints.get(tool_obj.name, "")

    response = llm_call.invoke(
        {
            "messages": messages,
            "prompt": prompt,
        }
    )

    return {"messages": response}
src/agents/rag_agent_template/prompt.py CHANGED
@@ -2,17 +2,9 @@ from langchain_core.prompts import ChatPromptTemplate
2
  from src.config.llm import get_llm
3
  from .tools import retrieve_document
4
 
5
- rag_prompt = ChatPromptTemplate.from_messages(
6
  [
7
  ("system", "{prompt}"),
8
  ("placeholder", "{messages}"),
9
  ]
10
  )
11
-
12
-
13
- def get_rag_chains(model_name: str):
14
- llm = get_llm(model_name)
15
- llm_rag = llm.bind_tools([retrieve_document])
16
- rag_answering_chain_tool = rag_prompt | llm_rag
17
- rag_answering_chain = rag_prompt | llm
18
- return rag_answering_chain_tool, rag_answering_chain
 
2
  from src.config.llm import get_llm
3
  from .tools import retrieve_document
4
 
5
# Shared chat prompt for the RAG answering chain: a per-request system
# prompt followed by the running conversation history (the "messages"
# placeholder is filled with the state's message list at invoke time).
template_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "{prompt}"),
        ("placeholder", "{messages}"),
    ]
)
 
 
 
 
 
 
 
 
src/agents/rag_agent_template/tools.py CHANGED
@@ -3,12 +3,18 @@ from src.config.vector_store import test_rag_vector_store
3
  from src.utils.helper import convert_list_context_source_to_str
4
  from src.utils.logger import logger
5
  from langchain_core.runnables import RunnableConfig
 
 
 
 
 
 
 
6
 
7
 
8
  @tool
9
  def retrieve_document(query: str, config: RunnableConfig):
10
  """Ưu tiên truy xuất tài liệu từ vector store nếu câu hỏi liên quan đến vai trò của chatbot.
11
-
12
 
13
  Args:
14
  query (str): Câu truy vấn của người dùng bằng tiếng Việt
@@ -34,3 +40,16 @@ def retrieve_document(query: str, config: RunnableConfig):
34
  "selected_documents": selected_documents,
35
  "selected_ids": selected_ids,
36
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from src.utils.helper import convert_list_context_source_to_str
4
  from src.utils.logger import logger
5
  from langchain_core.runnables import RunnableConfig
6
+ from langchain_experimental.utilities import PythonREPL
7
+ from langchain_community.tools import DuckDuckGoSearchRun
8
+
9
+
10
+ duckduckgo_search = DuckDuckGoSearchRun()
11
+
12
+ python_exec = PythonREPL()
13
 
14
 
15
  @tool
16
  def retrieve_document(query: str, config: RunnableConfig):
17
  """Ưu tiên truy xuất tài liệu từ vector store nếu câu hỏi liên quan đến vai trò của chatbot.
 
18
 
19
  Args:
20
  query (str): Câu truy vấn của người dùng bằng tiếng Việt
 
40
  "selected_documents": selected_documents,
41
  "selected_ids": selected_ids,
42
  }
43
+
44
+
45
@tool
def python_repl(code: str):
    """
    A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.

    Args:
        code (str): Python code to execute
    Returns:
        str: Output of the Python code
    """
    # SECURITY: this executes arbitrary LLM-generated Python in-process with
    # no sandboxing, resource limits, or timeout. Do not expose this tool to
    # untrusted users without isolating it (container/subprocess/jail).
    return python_exec.run(code)
src/apis/__pycache__/create_app.cpython-311.pyc CHANGED
Binary files a/src/apis/__pycache__/create_app.cpython-311.pyc and b/src/apis/__pycache__/create_app.cpython-311.pyc differ
 
src/apis/create_app.py CHANGED
@@ -4,13 +4,13 @@ from src.apis.routers.rag_agent_template import router as router_rag_agent_templ
4
  from src.apis.routers.file_processing_router import router as router_file_processing
5
  from src.apis.routers.custom_chatbot_router import router as custom_chatbot_processing
6
  from src.apis.routers.vector_store_router import router as vector_store_router
7
-
8
  api_router = APIRouter()
9
  api_router.include_router(router_rag_agent_template)
10
  api_router.include_router(router_file_processing)
11
  api_router.include_router(custom_chatbot_processing)
12
  api_router.include_router(vector_store_router)
13
-
14
  def create_app():
15
  app = FastAPI(
16
  docs_url="/",
 
4
  from src.apis.routers.file_processing_router import router as router_file_processing
5
  from src.apis.routers.custom_chatbot_router import router as custom_chatbot_processing
6
  from src.apis.routers.vector_store_router import router as vector_store_router
7
+ from src.apis.routers.tts_router import router as tts_router
8
  api_router = APIRouter()
9
  api_router.include_router(router_rag_agent_template)
10
  api_router.include_router(router_file_processing)
11
  api_router.include_router(custom_chatbot_processing)
12
  api_router.include_router(vector_store_router)
13
+ api_router.include_router(tts_router)
14
  def create_app():
15
  app = FastAPI(
16
  docs_url="/",
src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc CHANGED
Binary files a/src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc and b/src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc differ
 
src/apis/routers/__pycache__/tts.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
src/apis/routers/__pycache__/tts_router.cpython-311.pyc ADDED
Binary file (8.23 kB). View file
 
src/apis/routers/tts_router.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import VitsModel, AutoTokenizer
4
+ import torch
5
+ import os
6
+ import uuid
7
+ from fastapi.responses import FileResponse
8
+ import soundfile as sf
9
+ import numpy as np
10
+ from src.utils.logger import logger
11
+ from google import genai
12
+ from google.genai import types
13
+ import wave
14
+ from typing import Literal
15
+
16
+ router = APIRouter()
17
+
18
+
19
class TTSRequest(BaseModel):
    """Request body for the /tts/huggingface endpoint."""

    # Text to synthesize into speech.
    text: str
21
+
22
+
23
class GeminiTTSRequest(BaseModel):
    """Request body for the /tts/gemini endpoint."""

    # Text to synthesize into speech.
    text: str
    # Prebuilt Gemini voice to use for synthesis.
    voice_name: str = "Kore"
26
+
27
+
28
+ # Initialize model and tokenizer globally
29
+ try:
30
+ logger.info("Loading TTS model and tokenizer...")
31
+ model = VitsModel.from_pretrained("facebook/mms-tts-vie")
32
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
33
+ logger.info("TTS model and tokenizer loaded successfully")
34
+ except Exception as e:
35
+ logger.error(f"Failed to load TTS model: {str(e)}")
36
+ raise
37
+
38
+ # Initialize Google Gemini client
39
+ try:
40
+ logger.info("Initializing Google Gemini client...")
41
+ gemini_client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
42
+ logger.info("Google Gemini client initialized successfully")
43
+ except Exception as e:
44
+ logger.error(f"Failed to initialize Google Gemini client: {str(e)}")
45
+ raise
46
+
47
+
48
def save_wave_file(
    filename: str,
    pcm: bytes,
    channels: int = 1,
    rate: int = 24000,
    sample_width: int = 2,
):
    """Write raw PCM audio bytes to *filename* as an uncompressed WAV file."""
    with wave.open(filename, "wb") as wav_out:
        # setparams configures channel count, sample width and frame rate in
        # a single call; the frame count (0 here) is corrected automatically
        # by writeframes when the data is written.
        wav_out.setparams((channels, sample_width, rate, 0, "NONE", "not compressed"))
        wav_out.writeframes(pcm)
61
+
62
+
63
@router.post("/tts/huggingface")
async def huggingface_tts(request: TTSRequest):
    """Synthesize speech with the local MMS-TTS VITS model.

    Tokenizes ``request.text``, runs the module-level ``model`` to produce a
    waveform, writes it under ``audio_files/`` with a unique name, and
    returns the WAV file. Any failure is logged and surfaced as HTTP 500.
    """
    # NOTE(review): model inference below is synchronous and runs inside an
    # `async def` handler, so it blocks the event loop for the duration of
    # the forward pass — consider a plain `def` handler (FastAPI threadpool)
    # or an executor.
    # NOTE(review): files written to audio_files/ are never deleted; the
    # directory grows without bound.
    try:
        logger.info(
            f"Processing HuggingFace TTS request for text: {request.text[:50]}..."
        )

        # Tokenize input
        inputs = tokenizer(request.text, return_tensors="pt")
        logger.info("Text tokenized successfully")

        # Generate audio (no_grad: inference only, skip autograd bookkeeping)
        with torch.no_grad():
            output = model(**inputs).waveform
        logger.info("Audio generated successfully")

        # Convert tensor to numpy array
        audio_numpy = output.squeeze().cpu().numpy()

        # Create audio directory if it doesn't exist
        audio_dir = os.path.join(os.getcwd(), "audio_files")
        os.makedirs(audio_dir, exist_ok=True)

        # Generate unique filename
        audio_filename = f"huggingface_{uuid.uuid4()}.wav"
        audio_path = os.path.join(audio_dir, audio_filename)

        # Save audio file using soundfile at the model's native sampling rate
        sf.write(audio_path, audio_numpy, model.config.sampling_rate)
        logger.info(f"Audio saved to {audio_path}")

        # Return audio file
        return FileResponse(audio_path, media_type="audio/wav", filename=audio_filename)

    except Exception as e:
        logger.error(f"Error in huggingface_tts: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to generate speech: {str(e)}"
        )
102
+
103
+
104
@router.post("/tts/gemini")
def gemini_tts(request: GeminiTTSRequest):
    """Synthesize speech via the Gemini TTS preview model.

    Sends ``request.text`` to ``gemini-2.5-flash-preview-tts`` with the
    requested prebuilt voice, saves the returned audio bytes as a WAV under
    ``audio_files/``, and returns the file. Failures are logged and surfaced
    as HTTP 500.
    """
    # NOTE(review): files written to audio_files/ are never cleaned up.
    try:
        logger.info(f"Processing Gemini TTS request for text: {request.text[:50]}...")

        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=request.text,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=request.voice_name,
                        )
                    )
                ),
            ),
        )

        # Raw audio bytes from the first candidate's first content part.
        data = response.candidates[0].content.parts[0].inline_data.data

        # Create audio directory if it doesn't exist
        audio_dir = os.path.join(os.getcwd(), "audio_files")
        os.makedirs(audio_dir, exist_ok=True)

        # Generate unique filename
        audio_filename = f"gemini_{uuid.uuid4()}.wav"
        audio_path = os.path.join(audio_dir, audio_filename)

        # Save audio file (save_wave_file defaults: mono, 24 kHz, 16-bit PCM —
        # assumed to match Gemini TTS output; confirm against Gemini docs)
        save_wave_file(audio_path, data)
        logger.info(f"Audio saved to {audio_path}")

        # Return audio file
        return FileResponse(audio_path, media_type="audio/wav", filename=audio_filename)

    except Exception as e:
        logger.error(f"Error in gemini_tts: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to generate speech: {str(e)}"
        )