""" Model Manager for F5-TTS Thai จัดการการโหลดและเปลี่ยนโมเดล F5-TTS """ import os import torch from cached_path import cached_path from f5_tts.infer.utils_infer import load_model, load_vocoder from f5_tts.model import DiT from f5_tts.config import ( DEFAULT_MODEL_BASE, FP16_MODEL_BASE, VOCAB_BASE, VOCAB_HF, F5TTS_MODEL_CFG, MODEL_CHOICES ) class ModelManager: """จัดการการโหลดและเปลี่ยนโมเดล F5-TTS""" def __init__(self): self.f5tts_model = None self.vocoder = None self.current_model_path = None self._initialize() def _initialize(self): """เริ่มต้นโหลดโมเดลเริ่มต้น""" self.vocoder = load_vocoder() self.load_default_model() def load_default_model(self): """โหลดโมเดลเริ่มต้น""" self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE))) self.current_model_path = DEFAULT_MODEL_BASE print(f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}") def _load_f5tts_model(self, ckpt_path, vocab_path=VOCAB_BASE): """โหลดโมเดล F5-TTS""" vocab_file = vocab_path if os.path.exists(VOCAB_BASE) else str(cached_path(VOCAB_HF)) model = load_model( DiT, F5TTS_MODEL_CFG, ckpt_path, vocab_file=vocab_file, use_ema=True ) print(f"โหลดโมเดลจาก {ckpt_path}") return model def load_model_by_choice(self, model_choice, custom_path=None): """โหลดโมเดลตามตัวเลือก""" torch.cuda.empty_cache() try: if model_choice == "Custom": if not custom_path: raise ValueError("กรุณาระบุตำแหน่งโมเดลแบบกำหนดเอง") self.f5tts_model = self._load_f5tts_model(str(cached_path(custom_path))) self.current_model_path = custom_path return f"โหลดโมเดลแบบกำหนดเอง: {custom_path}" elif model_choice == "FP16": self.f5tts_model = self._load_f5tts_model(str(cached_path(FP16_MODEL_BASE))) self.current_model_path = FP16_MODEL_BASE return f"โหลดโมเดล FP16: {FP16_MODEL_BASE}" else: # Default self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE))) self.current_model_path = DEFAULT_MODEL_BASE return f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}" except Exception as e: error_msg = f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}" print(error_msg) return error_msg def get_model(self): """ดึงโมเดล F5-TTS ปัจจุบัน""" if self.f5tts_model is None: self.load_default_model() return self.f5tts_model def get_vocoder(self): """ดึง vocoder""" return self.vocoder def get_current_model_info(self): """ดึงข้อมูลโมเดลปัจจุบัน""" return { "model_path": self.current_model_path, "is_loaded": self.f5tts_model is not None } def update_custom_model_visibility(self, selected_model): """อัปเดตการแสดงผลของ custom model input""" return selected_model == "Custom"