File size: 8,680 Bytes
5669b22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import uuid
import os
import shutil
import json
from loguru import logger

from .agent.stateless_llm_factory import LLMFactory
from .asr.asr_factory import ASRFactory
from .tts.tts_factory import TTSFactory
from .translate.translate_factory import TranslateFactory
from prompts import prompt_loader

from .live2d_model import Live2dModel
from .audio_manager import AudioManager
from .conversation_manager import ConversationManager
from .interrupt_manager import InterruptManager

class OpenLLMVTuberMain:

    def __init__(

        self,

        configs: dict,

        custom_asr=None,

        custom_tts=None,

        websocket=None,

        loop=None,

    ) -> None:
        logger.info("t41372/Open-LLM-VTuber, version 1.0.0")
        
        self.config = configs
        self.verbose = self.config.get("VERBOSE", False)
        self.websocket = websocket
        self.live2d = self.init_live2d()
        self.session_id = str(uuid.uuid4().hex)
        self.loop = loop

        # ASR
        if self.config.get("VOICE_INPUT_ON", False):
            if custom_asr is None:
                self.asr = self.init_asr()
            else:
                print("Using custom ASR")
                self.asr = custom_asr
        else:
            self.asr = None

        # TTS
        if self.config.get("TTS_ON", False):
            if custom_tts is None:
                self.tts = self.init_tts()
            else:
                print("Using custom TTS")
                self.tts = custom_tts
        else:
            self.tts = None

        # Translator
        if self.config.get("TRANSLATE_AUDIO", False):
            try:
                translate_provider = self.config.get("TRANSLATE_PROVIDER", "DeepLX")
                self.translator = TranslateFactory.get_translator(
                    translate_provider=translate_provider,
                    **self.config.get(translate_provider, {}),
                )
            except Exception as e:
                print(f"Error initializing Translator: {e}")
                print("Proceed without Translator.")
                self.translator = None
        else:
            self.translator = None

        self.llm = self.init_llm()

        self.audio_manager = AudioManager(self.tts, self.live2d, self.translator, self.config, self.verbose)
        
        self.interrupt_manager = InterruptManager(self.llm)
        
        self.claude_api_key = self.config.get("CLAUDE_API_KEY", None)

        self.conversation_manager = ConversationManager(
            self.config, self.llm, self.asr, self.tts, self.live2d, self.translator, self.audio_manager, self.interrupt_manager, self.claude_api_key, self.verbose, self.loop
        )
        
        if "REMOVE_SPECIAL_CHAR" not in self.config:
            self.config["REMOVE_SPECIAL_CHAR"] = True
            
    def init_live2d(self):
        if not self.config.get("LIVE2D", False):
            return None
        try:
            live2d_model_name = self.config.get("LIVE2D_MODEL")
            live2d_controller = Live2dModel(live2d_model_name)
        except Exception as e:
            print(f"Error initializing Live2D: {e}")
            print("Proceed without Live2D.")
            return None
        return live2d_controller

    def init_llm(self):
        import yaml
        import os

        # 1. Đọc trực tiếp file conf.yaml
        try:
            with open("conf.yaml", "r", encoding="utf-8") as f:
                raw_config = yaml.safe_load(f)
        except Exception as e:
            logger.error(f"Lỗi đọc conf.yaml: {e}")
            raw_config = {}

        # 2. Truy xuất dữ liệu theo phân cấp
        char_cfg = raw_config.get("character_config", {})
        agent_cfg = char_cfg.get("agent_config", {})
        agent_settings = agent_cfg.get("agent_settings", {})
        
        llm_provider = agent_settings.get("basic_memory_agent", {}).get("llm_provider") or "openai_llm"
        llm_configs_pool = agent_cfg.get("llm_configs", {})
        llm_config = llm_configs_pool.get(llm_provider, {})

        # 3. Lấy Key, URL và Model
        api_key = llm_config.get("llm_api_key", "")
        base_url = llm_config.get("base_url", "https://openrouter.ai/api/v1")
        model = llm_config.get("model", "qwen/qwen3.6-plus:free")

        # 4. Thiết lập biến môi trường và lọc tham số phụ
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
            logger.info("✅ Đã BƠM TRỰC TIẾP API Key vào môi trường hệ thống!")
        
        # Lọc bỏ các tham số chính để tạo extra_kwargs
        extra_kwargs = {k: v for k, v in llm_config.items() if k not in ["llm_api_key", "base_url", "model"]}

        # 5. Lấy prompt và tools
        system_prompt, tools = self.get_system_prompt_and_tools()
        logger.info(f"Khởi tạo LLM: {llm_provider} | Model: {model} | Số lượng Tool: {len(tools)}")

        # 6. Khởi tạo LLM qua Factory
        llm = LLMFactory.create_llm(
            llm_provider=llm_provider, 
            SYSTEM_PROMPT=system_prompt,
            tools=tools, 
            caller=None,
            api_key=api_key,   # Chỗ này dùng api_key cho đúng chuẩn Factory
            base_url=base_url,
            model=model,
            **extra_kwargs
        )
        return llm

    def init_asr(self):
        asr_model = self.config.get("ASR_MODEL")
        asr_config = self.config.get(asr_model, {})
        asr = ASRFactory.get_asr_system(asr_model, **asr_config)
        return asr

    def init_tts(self):
        tts_model = self.config.get("TTS_MODEL", "pyttsx3TTS")
        tts_config = self.config.get(tts_model, {})
        return TTSFactory.get_tts_engine(tts_model, **tts_config)
    
    def get_song_list(self) -> list[str]:
        song_file_path = "./sing/original"
        if not os.path.exists(song_file_path):
            os.makedirs(song_file_path, exist_ok=True)
            return []
        song_list = os.listdir(song_file_path)
        return [os.path.splitext(song)[0] for song in song_list]
    
    def get_system_prompt_and_tools(self) -> tuple[str, list[dict]]:
        # 1. Lấy persona_prompt từ yaml của bạn (giữ nguyên logic cũ)
        system_prompt = self.config.get("persona_prompt") or ""
        
        # 2. Thêm chỉ dẫn công cụ vào prompt
        try:
            system_prompt += prompt_loader.load_util("tools_prompt").replace("[<insert_song_list>]", str(self.get_song_list()))
        except:
            pass

        # 3. ĐỌC TRỰC TIẾP FILE tools.json CỦA BẠN
        try:
            # Đường dẫn tương đối từ thư mục chạy project
            tools_path = os.path.join("prompts", "utils", "tools.json")
            with open(tools_path, 'r', encoding='utf-8') as f:
                tools = json.load(f)
        except Exception as e:
            logger.error(f"Không thể đọc file tools.json: {e}")
            # Dự phòng một tool trống nếu đọc file lỗi
            tools = []

        # 4. Cập nhật danh sách nhạc vào tool (nếu tool có chức năng play_music)
        if tools and len(tools) > 0:
            try:
                # Cập nhật enum cho bài hát để AI biết có bài gì mà chọn
                tools[0]["function"]["parameters"]["properties"]["song_name"]["enum"] = self.get_song_list()
            except:
                pass
        
        if self.verbose:
            print("\n === System Prompt ===")
            print(system_prompt)

        return system_prompt, tools
    def set_audio_output_func(

        self, audio_output_func

    ) -> None:
        self.audio_manager.play_audio_file = audio_output_func

    def clean_cache(self):
        cache_dir = "./cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir)

    async def conversation_chain(self, user_input=None, clipboard_data=None):
        # Dùng 'async for' để nhận từng gói payload từ conversation_manager
        async for payload in self.conversation_manager.conversation_chain(
            user_input=user_input, 
            clipboard_data=clipboard_data
        ):
            # Chuyển tiếp (yield) lên cho websocket_handler
            yield payload

    def interrupt(self, heard_sentence: str = "") -> None:
        self.interrupt_manager.interrupt(heard_sentence)