Spaces:
Sleeping
Sleeping
Upload 72 files
Browse files- src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc +0 -0
- src/agents/custom_chatbot/func.py +3 -3
- src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc +0 -0
- src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc +0 -0
- src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc +0 -0
- src/agents/rag_agent_template/func.py +41 -23
- src/agents/rag_agent_template/prompt.py +1 -9
- src/agents/rag_agent_template/tools.py +20 -1
- src/apis/__pycache__/create_app.cpython-311.pyc +0 -0
- src/apis/create_app.py +2 -2
- src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc +0 -0
- src/apis/routers/__pycache__/tts.cpython-311.pyc +0 -0
- src/apis/routers/__pycache__/tts_router.cpython-311.pyc +0 -0
- src/apis/routers/tts_router.py +145 -0
src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc
CHANGED
|
Binary files a/src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc and b/src/agents/custom_chatbot/__pycache__/func.cpython-311.pyc differ
|
|
|
src/agents/custom_chatbot/func.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import TypedDict,Optional
|
| 2 |
from langchain_core.messages import AnyMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
| 4 |
from typing import Sequence, Annotated
|
|
@@ -26,7 +26,7 @@ def get_info_collection(messages):
|
|
| 26 |
|
| 27 |
|
| 28 |
async def collection_info_agent(state: State):
|
| 29 |
-
model_name = state.get("model_name"
|
| 30 |
_, collection_info_agent = get_custom_chatbot_chains(model_name)
|
| 31 |
return await collection_info_agent.ainvoke(state)
|
| 32 |
|
|
@@ -35,7 +35,7 @@ async def create_prompt(state: State):
|
|
| 35 |
messages = state.get("messages")
|
| 36 |
name, info = get_info_collection(messages)
|
| 37 |
logger.info(f"create_prompt {info}")
|
| 38 |
-
model_name = state.get("model_name"
|
| 39 |
create_system_chain, _ = get_custom_chatbot_chains(model_name)
|
| 40 |
res = await create_system_chain.ainvoke({"info": info})
|
| 41 |
return {"prompt": res.content, "name": name}
|
|
|
|
| 1 |
+
from typing import TypedDict, Optional
|
| 2 |
from langchain_core.messages import AnyMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
| 4 |
from typing import Sequence, Annotated
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
async def collection_info_agent(state: State):
    """Run the information-collection chain on the current graph state.

    Reads ``model_name`` from ``state``, asks ``get_custom_chatbot_chains``
    for its pair of chains, and awaits the second one with the whole state.

    Args:
        state: Graph state; only ``model_name`` is read here, the full state
            is forwarded to the chain.

    Returns:
        Whatever the chain's ``ainvoke`` call returns.
    """
    requested_model = state.get("model_name")
    # The helper returns two chains; only the second one is used by this node.
    _system_chain, info_chain = get_custom_chatbot_chains(requested_model)
    result = await info_chain.ainvoke(state)
    return result
|
| 32 |
|
|
|
|
| 35 |
messages = state.get("messages")
|
| 36 |
name, info = get_info_collection(messages)
|
| 37 |
logger.info(f"create_prompt {info}")
|
| 38 |
+
model_name = state.get("model_name")
|
| 39 |
create_system_chain, _ = get_custom_chatbot_chains(model_name)
|
| 40 |
res = await create_system_chain.ainvoke({"info": info})
|
| 41 |
return {"prompt": res.content, "name": name}
|
src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc
CHANGED
|
Binary files a/src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/func.cpython-311.pyc differ
|
|
|
src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc
CHANGED
|
Binary files a/src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/prompt.cpython-311.pyc differ
|
|
|
src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc
CHANGED
|
Binary files a/src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc and b/src/agents/rag_agent_template/__pycache__/tools.cpython-311.pyc differ
|
|
|
src/agents/rag_agent_template/func.py
CHANGED
|
@@ -1,27 +1,30 @@
|
|
| 1 |
from typing import TypedDict, Optional, List
|
| 2 |
from langchain_core.messages import AnyMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
| 4 |
-
from .prompt import get_rag_chains
|
| 5 |
from typing import Sequence, Annotated
|
| 6 |
from langchain_core.messages import RemoveMessage
|
| 7 |
from langchain_core.documents import Document
|
| 8 |
-
from .tools import retrieve_document
|
| 9 |
from src.utils.logger import logger
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
tools = [retrieve_document]
|
| 12 |
|
| 13 |
|
| 14 |
class State(TypedDict):
|
| 15 |
messages: Annotated[Sequence[AnyMessage], add_messages]
|
| 16 |
selected_ids: Optional[List[str]]
|
| 17 |
selected_documents: Optional[List[Document]]
|
| 18 |
-
tools:
|
| 19 |
prompt: str
|
| 20 |
model_name: Optional[str]
|
| 21 |
|
| 22 |
|
| 23 |
def trim_history(state: State):
|
| 24 |
history = state.get("messages", [])
|
|
|
|
|
|
|
| 25 |
if len(history) > 10:
|
| 26 |
num_to_remove = len(history) - 10
|
| 27 |
remove_messages = [
|
|
@@ -38,7 +41,10 @@ def trim_history(state: State):
|
|
| 38 |
|
| 39 |
def execute_tool(state: State):
|
| 40 |
tool_calls = state["messages"][-1].tool_calls
|
|
|
|
| 41 |
tool_name_to_func = {tool.name: tool for tool in tools}
|
|
|
|
|
|
|
| 42 |
selected_ids = []
|
| 43 |
selected_documents = []
|
| 44 |
tool_messages = []
|
|
@@ -64,7 +70,11 @@ def execute_tool(state: State):
|
|
| 64 |
)
|
| 65 |
continue
|
| 66 |
tool_response = tool_func.invoke(tool_args)
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
return {
|
| 70 |
"selected_ids": selected_ids,
|
|
@@ -75,22 +85,30 @@ def execute_tool(state: State):
|
|
| 75 |
|
| 76 |
def generate_answer_rag(state: State):
|
| 77 |
messages = state["messages"]
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
return {"messages": response}
|
|
|
|
| 1 |
from typing import TypedDict, Optional, List
|
| 2 |
from langchain_core.messages import AnyMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
|
|
|
| 4 |
from typing import Sequence, Annotated
|
| 5 |
from langchain_core.messages import RemoveMessage
|
| 6 |
from langchain_core.documents import Document
|
| 7 |
+
from .tools import retrieve_document, python_repl, duckduckgo_search
|
| 8 |
from src.utils.logger import logger
|
| 9 |
+
from src.config.llm import get_llm
|
| 10 |
+
from .prompt import template_prompt
|
| 11 |
|
| 12 |
+
tools = [retrieve_document, python_repl, duckduckgo_search]
|
| 13 |
|
| 14 |
|
| 15 |
class State(TypedDict):
    # Typed state dictionary read by the node functions in this module.

    # Conversation messages, annotated with the `add_messages` reducer.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # NOTE(review): presumably the ids/documents picked by the retrieval tool;
    # `execute_tool` returns a "selected_ids" entry -- confirm against the full file.
    selected_ids: Optional[List[str]]
    selected_documents: Optional[List[Document]]
    # Names of the enabled tools; nodes match them against the `.name` of the
    # entries in the module-level `tools` list.
    tools: Optional[List[str]]
    # System prompt text; `generate_answer_rag` passes it to the prompt template.
    prompt: str
    # Model identifier that `generate_answer_rag` hands to `get_llm`.
    model_name: Optional[str]
|
| 22 |
|
| 23 |
|
| 24 |
def trim_history(state: State):
|
| 25 |
history = state.get("messages", [])
|
| 26 |
+
tool_names = state.get("tools", [])
|
| 27 |
+
|
| 28 |
if len(history) > 10:
|
| 29 |
num_to_remove = len(history) - 10
|
| 30 |
remove_messages = [
|
|
|
|
| 41 |
|
| 42 |
def execute_tool(state: State):
|
| 43 |
tool_calls = state["messages"][-1].tool_calls
|
| 44 |
+
tool_names = state.get("tools", [])
|
| 45 |
tool_name_to_func = {tool.name: tool for tool in tools}
|
| 46 |
+
tool_functions = [tool_name_to_func[name] for name in tool_names if name in tool_name_to_func]
|
| 47 |
+
|
| 48 |
selected_ids = []
|
| 49 |
selected_documents = []
|
| 50 |
tool_messages = []
|
|
|
|
| 70 |
)
|
| 71 |
continue
|
| 72 |
tool_response = tool_func.invoke(tool_args)
|
| 73 |
+
print(f"tool_response: {tool_response}")
|
| 74 |
+
tool_messages.append(ToolMessage(
|
| 75 |
+
tool_call_id=tool_id,
|
| 76 |
+
content=tool_response,
|
| 77 |
+
))
|
| 78 |
|
| 79 |
return {
|
| 80 |
"selected_ids": selected_ids,
|
|
|
|
| 85 |
|
| 86 |
def generate_answer_rag(state: State):
    """Generate the next assistant message, binding the enabled tools if any.

    Reads ``messages``, ``tools``, ``prompt`` and ``model_name`` from ``state``,
    resolves the requested tool names against the module-level ``tools`` list,
    appends one usage hint per enabled tool to the system prompt, and invokes
    ``template_prompt | llm`` with the messages and the extended prompt.

    Args:
        state: Graph state. ``messages`` and ``prompt`` are required; ``tools``
            and ``model_name`` may be missing or ``None``.

    Returns:
        dict: ``{"messages": response}`` where ``response`` is the value
        returned by invoking the chain.

    Raises:
        KeyError: If ``messages`` or ``prompt`` is missing from ``state``.
    """
    messages = state["messages"]
    prompt = state["prompt"]
    # Both fields are Optional in State: `.get(key, default)` still yields None
    # when the key is present with a None value, so fall back explicitly.
    tool_names = state.get("tools") or []
    model_name = state.get("model_name") or "gemini-2.0-flash"

    # Unknown tool names are silently ignored.
    tool_name_to_func = {tool.name: tool for tool in tools}
    tool_functions = [
        tool_name_to_func[name] for name in tool_names if name in tool_name_to_func
    ]
    logger.info(f"tools: {tool_functions}")

    llm = get_llm(model_name)
    # Only bind tools when at least one is enabled; an empty tool list is
    # pointless and not accepted by every chat-model integration.
    if tool_functions:
        llm = llm.bind_tools(tool_functions)
    llm_call = template_prompt | llm

    # Per-tool usage hints appended to the system prompt (runtime text, kept
    # in Vietnamese exactly as authored).
    usage_hints = {
        "retrieve_document": "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời",
        "python_repl": "Sử dụng tool `python_repl` để thực hiện các tác vụ liên quan đến tính toán phức tạp",
        "duckduckgo_search": "Sử dụng tool `duckduckgo_search` để tìm kiếm thông tin trên internet",
    }
    hints = [
        usage_hints[tool.name] for tool in tool_functions if tool.name in usage_hints
    ]
    if hints:
        # Separate the hints from the prompt and from each other; plain
        # concatenation fused the sentences together.
        prompt = "\n".join([prompt, *hints])

    response = llm_call.invoke({"messages": messages, "prompt": prompt})
    return {"messages": response}
|
src/agents/rag_agent_template/prompt.py
CHANGED
|
@@ -2,17 +2,9 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
| 2 |
from src.config.llm import get_llm
|
| 3 |
from .tools import retrieve_document
|
| 4 |
|
| 5 |
-
|
| 6 |
[
|
| 7 |
("system", "{prompt}"),
|
| 8 |
("placeholder", "{messages}"),
|
| 9 |
]
|
| 10 |
)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def get_rag_chains(model_name: str):
|
| 14 |
-
llm = get_llm(model_name)
|
| 15 |
-
llm_rag = llm.bind_tools([retrieve_document])
|
| 16 |
-
rag_answering_chain_tool = rag_prompt | llm_rag
|
| 17 |
-
rag_answering_chain = rag_prompt | llm
|
| 18 |
-
return rag_answering_chain_tool, rag_answering_chain
|
|
|
|
| 2 |
from src.config.llm import get_llm
|
| 3 |
from .tools import retrieve_document
|
| 4 |
|
| 5 |
+
template_prompt = ChatPromptTemplate.from_messages(
|
| 6 |
[
|
| 7 |
("system", "{prompt}"),
|
| 8 |
("placeholder", "{messages}"),
|
| 9 |
]
|
| 10 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/agents/rag_agent_template/tools.py
CHANGED
|
@@ -3,12 +3,18 @@ from src.config.vector_store import test_rag_vector_store
|
|
| 3 |
from src.utils.helper import convert_list_context_source_to_str
|
| 4 |
from src.utils.logger import logger
|
| 5 |
from langchain_core.runnables import RunnableConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
@tool
|
| 9 |
def retrieve_document(query: str, config: RunnableConfig):
|
| 10 |
"""Ưu tiên truy xuất tài liệu từ vector store nếu câu hỏi liên quan đến vai trò của chatbot.
|
| 11 |
-
|
| 12 |
|
| 13 |
Args:
|
| 14 |
query (str): Câu truy vấn của người dùng bằng tiếng Việt
|
|
@@ -34,3 +40,16 @@ def retrieve_document(query: str, config: RunnableConfig):
|
|
| 34 |
"selected_documents": selected_documents,
|
| 35 |
"selected_ids": selected_ids,
|
| 36 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from src.utils.helper import convert_list_context_source_to_str
|
| 4 |
from src.utils.logger import logger
|
| 5 |
from langchain_core.runnables import RunnableConfig
|
| 6 |
+
from langchain_experimental.utilities import PythonREPL
|
| 7 |
+
from langchain_community.tools import DuckDuckGoSearchRun
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
duckduckgo_search = DuckDuckGoSearchRun()
|
| 11 |
+
|
| 12 |
+
python_exec = PythonREPL()
|
| 13 |
|
| 14 |
|
| 15 |
@tool
|
| 16 |
def retrieve_document(query: str, config: RunnableConfig):
|
| 17 |
"""Ưu tiên truy xuất tài liệu từ vector store nếu câu hỏi liên quan đến vai trò của chatbot.
|
|
|
|
| 18 |
|
| 19 |
Args:
|
| 20 |
query (str): Câu truy vấn của người dùng bằng tiếng Việt
|
|
|
|
| 40 |
"selected_documents": selected_documents,
|
| 41 |
"selected_ids": selected_ids,
|
| 42 |
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# NOTE(review): `tool` is presumably LangChain's decorator (its import is not
# visible in this hunk); if so, the docstring below is sent to the model as the
# tool description, so treat its wording as runtime text.
@tool
def python_repl(code: str):
    """
    A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.

    Args:
        code (str): Python code to execute
    Returns:
        str: Output of the Python code
    """
    # SECURITY: `code` is handed unchanged to the module-level PythonREPL
    # instance (`python_exec`); this function applies no validation or
    # sandboxing itself. When the tool is bound to an LLM, the code is
    # model-generated -- TODO(review): confirm the deployment isolates this.
    return python_exec.run(code)
|
src/apis/__pycache__/create_app.cpython-311.pyc
CHANGED
|
Binary files a/src/apis/__pycache__/create_app.cpython-311.pyc and b/src/apis/__pycache__/create_app.cpython-311.pyc differ
|
|
|
src/apis/create_app.py
CHANGED
|
@@ -4,13 +4,13 @@ from src.apis.routers.rag_agent_template import router as router_rag_agent_templ
|
|
| 4 |
from src.apis.routers.file_processing_router import router as router_file_processing
|
| 5 |
from src.apis.routers.custom_chatbot_router import router as custom_chatbot_processing
|
| 6 |
from src.apis.routers.vector_store_router import router as vector_store_router
|
| 7 |
-
|
| 8 |
api_router = APIRouter()
|
| 9 |
api_router.include_router(router_rag_agent_template)
|
| 10 |
api_router.include_router(router_file_processing)
|
| 11 |
api_router.include_router(custom_chatbot_processing)
|
| 12 |
api_router.include_router(vector_store_router)
|
| 13 |
-
|
| 14 |
def create_app():
|
| 15 |
app = FastAPI(
|
| 16 |
docs_url="/",
|
|
|
|
| 4 |
from src.apis.routers.file_processing_router import router as router_file_processing
|
| 5 |
from src.apis.routers.custom_chatbot_router import router as custom_chatbot_processing
|
| 6 |
from src.apis.routers.vector_store_router import router as vector_store_router
|
| 7 |
+
from src.apis.routers.tts_router import router as tts_router
|
| 8 |
api_router = APIRouter()
|
| 9 |
api_router.include_router(router_rag_agent_template)
|
| 10 |
api_router.include_router(router_file_processing)
|
| 11 |
api_router.include_router(custom_chatbot_processing)
|
| 12 |
api_router.include_router(vector_store_router)
|
| 13 |
+
api_router.include_router(tts_router)
|
| 14 |
def create_app():
|
| 15 |
app = FastAPI(
|
| 16 |
docs_url="/",
|
src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc
CHANGED
|
Binary files a/src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc and b/src/apis/routers/__pycache__/file_processing_router.cpython-311.pyc differ
|
|
|
src/apis/routers/__pycache__/tts.cpython-311.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
src/apis/routers/__pycache__/tts_router.cpython-311.pyc
ADDED
|
Binary file (8.23 kB). View file
|
|
|
src/apis/routers/tts_router.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from transformers import VitsModel, AutoTokenizer
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import uuid
|
| 7 |
+
from fastapi.responses import FileResponse
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import numpy as np
|
| 10 |
+
from src.utils.logger import logger
|
| 11 |
+
from google import genai
|
| 12 |
+
from google.genai import types
|
| 13 |
+
import wave
|
| 14 |
+
from typing import Literal
|
| 15 |
+
|
| 16 |
+
# Router on which the TTS endpoints in this module are registered.
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Request body for the /tts/huggingface endpoint.
class TTSRequest(BaseModel):
    # Text to synthesize.
    text: str
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Request body for the /tts/gemini endpoint.
class GeminiTTSRequest(BaseModel):
    # Text to synthesize.
    text: str
    # Forwarded as `voice_name` to the Gemini prebuilt voice configuration.
    voice_name: str = "Kore"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Initialize model and tokenizer globally
# This runs at import time: a loading failure is logged and re-raised, so
# importing this module fails when the model cannot be loaded.
try:
    logger.info("Loading TTS model and tokenizer...")
    model = VitsModel.from_pretrained("facebook/mms-tts-vie")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-vie")
    logger.info("TTS model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load TTS model: {str(e)}")
    raise

# Initialize Google Gemini client
# Also executed at import time; the API key is read from the GOOGLE_API_KEY
# environment variable (os.getenv returns None when it is unset).
try:
    logger.info("Initializing Google Gemini client...")
    gemini_client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
    logger.info("Google Gemini client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Google Gemini client: {str(e)}")
    raise
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def save_wave_file(
    filename: str,
    pcm: bytes,
    channels: int = 1,
    rate: int = 24000,
    sample_width: int = 2,
):
    """Save raw PCM data to a WAV file.

    Args:
        filename: Destination path. A ``str`` or any ``os.PathLike`` (for
            example ``pathlib.Path``) is accepted; other objects are passed
            to ``wave.open`` unchanged.
        pcm: Raw PCM frames to write.
        channels: Number of audio channels written to the header.
        rate: Frame rate (samples per second) written to the header.
        sample_width: Sample width in bytes written to the header.
    """
    # wave.open() does not special-case path-like objects on every Python
    # version (older ones only recognise `str`), so normalise them here.
    if isinstance(filename, os.PathLike):
        filename = os.fspath(filename)

    with wave.open(filename, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(rate)
        wf.writeframes(pcm)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.post("/tts/huggingface")
async def huggingface_tts(request: TTSRequest):
    """Synthesize speech for ``request.text`` with the module-level VITS model.

    The waveform is written to ``<cwd>/audio_files/huggingface_<uuid>.wav``
    and that file is returned as an ``audio/wav`` response.

    Args:
        request: Request body carrying the text to synthesize.

    Returns:
        FileResponse: The generated WAV file.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only;
            500 when tokenization, inference or saving fails.
    """
    import asyncio  # local import: only this handler needs it

    def _synthesize():
        """Run tokenization, inference and file writing; return (path, name)."""
        # Tokenize input
        inputs = tokenizer(request.text, return_tensors="pt")
        logger.info("Text tokenized successfully")

        # Generate audio. torch.no_grad() is thread-local, so it has to be
        # entered inside the worker thread that runs the forward pass.
        with torch.no_grad():
            output = model(**inputs).waveform
        logger.info("Audio generated successfully")

        # Convert tensor to numpy array
        audio_numpy = output.squeeze().cpu().numpy()

        # Create audio directory if it doesn't exist
        audio_dir = os.path.join(os.getcwd(), "audio_files")
        os.makedirs(audio_dir, exist_ok=True)

        # Generate unique filename
        audio_filename = f"huggingface_{uuid.uuid4()}.wav"
        audio_path = os.path.join(audio_dir, audio_filename)

        # Save audio file using soundfile
        # NOTE(review): generated files are never removed, so this directory
        # grows with every request -- consider a cleanup policy.
        sf.write(audio_path, audio_numpy, model.config.sampling_rate)
        logger.info(f"Audio saved to {audio_path}")
        return audio_path, audio_filename

    try:
        logger.info(
            f"Processing HuggingFace TTS request for text: {request.text[:50]}..."
        )

        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text must not be empty")

        # Inference is CPU-bound and blocking; running it directly in this
        # coroutine would stall the event loop for every other request.
        audio_path, audio_filename = await asyncio.to_thread(_synthesize)

        # Return audio file
        return FileResponse(audio_path, media_type="audio/wav", filename=audio_filename)

    except HTTPException:
        # Deliberate HTTP errors must not be rewritten into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in huggingface_tts: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to generate speech: {str(e)}"
        ) from e
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@router.post("/tts/gemini")
def gemini_tts(request: GeminiTTSRequest):
    """Synthesize speech for ``request.text`` with the Gemini TTS model.

    Calls ``gemini_client.models.generate_content`` requesting audio output
    with the requested prebuilt voice, writes the returned PCM data to
    ``<cwd>/audio_files/gemini_<uuid>.wav`` via ``save_wave_file`` and returns
    that file as an ``audio/wav`` response.

    Args:
        request: Request body carrying the text and the voice name.

    Returns:
        FileResponse: The generated WAV file.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only;
            500 when the Gemini call fails, returns no audio, or saving fails.
    """
    try:
        logger.info(f"Processing Gemini TTS request for text: {request.text[:50]}...")

        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text must not be empty")

        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-preview-tts",
            contents=request.text,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=request.voice_name,
                        )
                    )
                ),
            ),
        )

        # Walk the response defensively: any missing level would otherwise
        # surface as a cryptic IndexError/AttributeError.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts or []) if content is not None else []
        inline_data = parts[0].inline_data if parts else None
        data = inline_data.data if inline_data is not None else None
        if not data:
            raise ValueError("Gemini response contained no audio data")

        # Create audio directory if it doesn't exist
        audio_dir = os.path.join(os.getcwd(), "audio_files")
        os.makedirs(audio_dir, exist_ok=True)

        # Generate unique filename
        audio_filename = f"gemini_{uuid.uuid4()}.wav"
        audio_path = os.path.join(audio_dir, audio_filename)

        # Save audio file
        # NOTE(review): generated files are never removed, so this directory
        # grows with every request -- consider a cleanup policy.
        save_wave_file(audio_path, data)
        logger.info(f"Audio saved to {audio_path}")

        # Return audio file
        return FileResponse(audio_path, media_type="audio/wav", filename=audio_filename)

    except HTTPException:
        # Deliberate HTTP errors must not be rewritten into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in gemini_tts: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to generate speech: {str(e)}"
        ) from e
|