test_ui / src /open_llm_vtuber /openllm_vtuber_main.py
britto224's picture
Upload 130 files
5669b22 verified
import uuid
import os
import shutil
import json
from loguru import logger
from .agent.stateless_llm_factory import LLMFactory
from .asr.asr_factory import ASRFactory
from .tts.tts_factory import TTSFactory
from .translate.translate_factory import TranslateFactory
from prompts import prompt_loader
from .live2d_model import Live2dModel
from .audio_manager import AudioManager
from .conversation_manager import ConversationManager
from .interrupt_manager import InterruptManager
class OpenLLMVTuberMain:
def __init__(
self,
configs: dict,
custom_asr=None,
custom_tts=None,
websocket=None,
loop=None,
) -> None:
logger.info("t41372/Open-LLM-VTuber, version 1.0.0")
self.config = configs
self.verbose = self.config.get("VERBOSE", False)
self.websocket = websocket
self.live2d = self.init_live2d()
self.session_id = str(uuid.uuid4().hex)
self.loop = loop
# ASR
if self.config.get("VOICE_INPUT_ON", False):
if custom_asr is None:
self.asr = self.init_asr()
else:
print("Using custom ASR")
self.asr = custom_asr
else:
self.asr = None
# TTS
if self.config.get("TTS_ON", False):
if custom_tts is None:
self.tts = self.init_tts()
else:
print("Using custom TTS")
self.tts = custom_tts
else:
self.tts = None
# Translator
if self.config.get("TRANSLATE_AUDIO", False):
try:
translate_provider = self.config.get("TRANSLATE_PROVIDER", "DeepLX")
self.translator = TranslateFactory.get_translator(
translate_provider=translate_provider,
**self.config.get(translate_provider, {}),
)
except Exception as e:
print(f"Error initializing Translator: {e}")
print("Proceed without Translator.")
self.translator = None
else:
self.translator = None
self.llm = self.init_llm()
self.audio_manager = AudioManager(self.tts, self.live2d, self.translator, self.config, self.verbose)
self.interrupt_manager = InterruptManager(self.llm)
self.claude_api_key = self.config.get("CLAUDE_API_KEY", None)
self.conversation_manager = ConversationManager(
self.config, self.llm, self.asr, self.tts, self.live2d, self.translator, self.audio_manager, self.interrupt_manager, self.claude_api_key, self.verbose, self.loop
)
if "REMOVE_SPECIAL_CHAR" not in self.config:
self.config["REMOVE_SPECIAL_CHAR"] = True
def init_live2d(self):
if not self.config.get("LIVE2D", False):
return None
try:
live2d_model_name = self.config.get("LIVE2D_MODEL")
live2d_controller = Live2dModel(live2d_model_name)
except Exception as e:
print(f"Error initializing Live2D: {e}")
print("Proceed without Live2D.")
return None
return live2d_controller
def init_llm(self):
import yaml
import os
# 1. Đọc trực tiếp file conf.yaml
try:
with open("conf.yaml", "r", encoding="utf-8") as f:
raw_config = yaml.safe_load(f)
except Exception as e:
logger.error(f"Lỗi đọc conf.yaml: {e}")
raw_config = {}
# 2. Truy xuất dữ liệu theo phân cấp
char_cfg = raw_config.get("character_config", {})
agent_cfg = char_cfg.get("agent_config", {})
agent_settings = agent_cfg.get("agent_settings", {})
llm_provider = agent_settings.get("basic_memory_agent", {}).get("llm_provider") or "openai_llm"
llm_configs_pool = agent_cfg.get("llm_configs", {})
llm_config = llm_configs_pool.get(llm_provider, {})
# 3. Lấy Key, URL và Model
api_key = llm_config.get("llm_api_key", "")
base_url = llm_config.get("base_url", "https://openrouter.ai/api/v1")
model = llm_config.get("model", "qwen/qwen3.6-plus:free")
# 4. Thiết lập biến môi trường và lọc tham số phụ
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
logger.info("✅ Đã BƠM TRỰC TIẾP API Key vào môi trường hệ thống!")
# Lọc bỏ các tham số chính để tạo extra_kwargs
extra_kwargs = {k: v for k, v in llm_config.items() if k not in ["llm_api_key", "base_url", "model"]}
# 5. Lấy prompt và tools
system_prompt, tools = self.get_system_prompt_and_tools()
logger.info(f"Khởi tạo LLM: {llm_provider} | Model: {model} | Số lượng Tool: {len(tools)}")
# 6. Khởi tạo LLM qua Factory
llm = LLMFactory.create_llm(
llm_provider=llm_provider,
SYSTEM_PROMPT=system_prompt,
tools=tools,
caller=None,
api_key=api_key, # Chỗ này dùng api_key cho đúng chuẩn Factory
base_url=base_url,
model=model,
**extra_kwargs
)
return llm
def init_asr(self):
asr_model = self.config.get("ASR_MODEL")
asr_config = self.config.get(asr_model, {})
asr = ASRFactory.get_asr_system(asr_model, **asr_config)
return asr
def init_tts(self):
tts_model = self.config.get("TTS_MODEL", "pyttsx3TTS")
tts_config = self.config.get(tts_model, {})
return TTSFactory.get_tts_engine(tts_model, **tts_config)
def get_song_list(self) -> list[str]:
song_file_path = "./sing/original"
if not os.path.exists(song_file_path):
os.makedirs(song_file_path, exist_ok=True)
return []
song_list = os.listdir(song_file_path)
return [os.path.splitext(song)[0] for song in song_list]
def get_system_prompt_and_tools(self) -> tuple[str, list[dict]]:
# 1. Lấy persona_prompt từ yaml của bạn (giữ nguyên logic cũ)
system_prompt = self.config.get("persona_prompt") or ""
# 2. Thêm chỉ dẫn công cụ vào prompt
try:
system_prompt += prompt_loader.load_util("tools_prompt").replace("[<insert_song_list>]", str(self.get_song_list()))
except:
pass
# 3. ĐỌC TRỰC TIẾP FILE tools.json CỦA BẠN
try:
# Đường dẫn tương đối từ thư mục chạy project
tools_path = os.path.join("prompts", "utils", "tools.json")
with open(tools_path, 'r', encoding='utf-8') as f:
tools = json.load(f)
except Exception as e:
logger.error(f"Không thể đọc file tools.json: {e}")
# Dự phòng một tool trống nếu đọc file lỗi
tools = []
# 4. Cập nhật danh sách nhạc vào tool (nếu tool có chức năng play_music)
if tools and len(tools) > 0:
try:
# Cập nhật enum cho bài hát để AI biết có bài gì mà chọn
tools[0]["function"]["parameters"]["properties"]["song_name"]["enum"] = self.get_song_list()
except:
pass
if self.verbose:
print("\n === System Prompt ===")
print(system_prompt)
return system_prompt, tools
def set_audio_output_func(
self, audio_output_func
) -> None:
self.audio_manager.play_audio_file = audio_output_func
def clean_cache(self):
cache_dir = "./cache"
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
os.makedirs(cache_dir)
async def conversation_chain(self, user_input=None, clipboard_data=None):
# Dùng 'async for' để nhận từng gói payload từ conversation_manager
async for payload in self.conversation_manager.conversation_chain(
user_input=user_input,
clipboard_data=clipboard_data
):
# Chuyển tiếp (yield) lên cho websocket_handler
yield payload
def interrupt(self, heard_sentence: str = "") -> None:
self.interrupt_manager.interrupt(heard_sentence)