Spaces:
Sleeping
Sleeping
| """ | |
| Model Manager for F5-TTS Thai | |
| จัดการการโหลดและเปลี่ยนโมเดล F5-TTS | |
| """ | |
| import os | |
| import torch | |
| from cached_path import cached_path | |
| from f5_tts.infer.utils_infer import load_model, load_vocoder | |
| from f5_tts.model import DiT | |
| from f5_tts.config import ( | |
| DEFAULT_MODEL_BASE, | |
| FP16_MODEL_BASE, | |
| VOCAB_BASE, | |
| VOCAB_HF, | |
| F5TTS_MODEL_CFG, | |
| MODEL_CHOICES | |
| ) | |
| class ModelManager: | |
| """จัดการการโหลดและเปลี่ยนโมเดล F5-TTS""" | |
| def __init__(self): | |
| self.f5tts_model = None | |
| self.vocoder = None | |
| self.current_model_path = None | |
| self._initialize() | |
| def _initialize(self): | |
| """เริ่มต้นโหลดโมเดลเริ่มต้น""" | |
| self.vocoder = load_vocoder() | |
| self.load_default_model() | |
| def load_default_model(self): | |
| """โหลดโมเดลเริ่มต้น""" | |
| self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE))) | |
| self.current_model_path = DEFAULT_MODEL_BASE | |
| print(f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}") | |
| def _load_f5tts_model(self, ckpt_path, vocab_path=VOCAB_BASE): | |
| """โหลดโมเดล F5-TTS""" | |
| vocab_file = vocab_path if os.path.exists(VOCAB_BASE) else str(cached_path(VOCAB_HF)) | |
| model = load_model( | |
| DiT, | |
| F5TTS_MODEL_CFG, | |
| ckpt_path, | |
| vocab_file=vocab_file, | |
| use_ema=True | |
| ) | |
| print(f"โหลดโมเดลจาก {ckpt_path}") | |
| return model | |
| def load_model_by_choice(self, model_choice, custom_path=None): | |
| """โหลดโมเดลตามตัวเลือก""" | |
| torch.cuda.empty_cache() | |
| try: | |
| if model_choice == "Custom": | |
| if not custom_path: | |
| raise ValueError("กรุณาระบุตำแหน่งโมเดลแบบกำหนดเอง") | |
| self.f5tts_model = self._load_f5tts_model(str(cached_path(custom_path))) | |
| self.current_model_path = custom_path | |
| return f"โหลดโมเดลแบบกำหนดเอง: {custom_path}" | |
| elif model_choice == "FP16": | |
| self.f5tts_model = self._load_f5tts_model(str(cached_path(FP16_MODEL_BASE))) | |
| self.current_model_path = FP16_MODEL_BASE | |
| return f"โหลดโมเดล FP16: {FP16_MODEL_BASE}" | |
| else: # Default | |
| self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE))) | |
| self.current_model_path = DEFAULT_MODEL_BASE | |
| return f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}" | |
| except Exception as e: | |
| error_msg = f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}" | |
| print(error_msg) | |
| return error_msg | |
| def get_model(self): | |
| """ดึงโมเดล F5-TTS ปัจจุบัน""" | |
| if self.f5tts_model is None: | |
| self.load_default_model() | |
| return self.f5tts_model | |
| def get_vocoder(self): | |
| """ดึง vocoder""" | |
| return self.vocoder | |
| def get_current_model_info(self): | |
| """ดึงข้อมูลโมเดลปัจจุบัน""" | |
| return { | |
| "model_path": self.current_model_path, | |
| "is_loaded": self.f5tts_model is not None | |
| } | |
| def update_custom_model_visibility(self, selected_model): | |
| """อัปเดตการแสดงผลของ custom model input""" | |
| return selected_model == "Custom" |