F5-TTS-THAI / deployment /src /f5_tts /model_manager.py
pythonlearnreal's picture
Upload folder using huggingface_hub
106478e verified
"""
Model Manager for F5-TTS Thai
จัดการการโหลดและเปลี่ยนโมเดล F5-TTS
"""
import os
import torch
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT
from f5_tts.config import (
DEFAULT_MODEL_BASE,
FP16_MODEL_BASE,
VOCAB_BASE,
VOCAB_HF,
F5TTS_MODEL_CFG,
MODEL_CHOICES
)
class ModelManager:
"""จัดการการโหลดและเปลี่ยนโมเดล F5-TTS"""
def __init__(self):
self.f5tts_model = None
self.vocoder = None
self.current_model_path = None
self._initialize()
def _initialize(self):
"""เริ่มต้นโหลดโมเดลเริ่มต้น"""
self.vocoder = load_vocoder()
self.load_default_model()
def load_default_model(self):
"""โหลดโมเดลเริ่มต้น"""
self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE)))
self.current_model_path = DEFAULT_MODEL_BASE
print(f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}")
def _load_f5tts_model(self, ckpt_path, vocab_path=VOCAB_BASE):
"""โหลดโมเดล F5-TTS"""
vocab_file = vocab_path if os.path.exists(VOCAB_BASE) else str(cached_path(VOCAB_HF))
model = load_model(
DiT,
F5TTS_MODEL_CFG,
ckpt_path,
vocab_file=vocab_file,
use_ema=True
)
print(f"โหลดโมเดลจาก {ckpt_path}")
return model
def load_model_by_choice(self, model_choice, custom_path=None):
"""โหลดโมเดลตามตัวเลือก"""
torch.cuda.empty_cache()
try:
if model_choice == "Custom":
if not custom_path:
raise ValueError("กรุณาระบุตำแหน่งโมเดลแบบกำหนดเอง")
self.f5tts_model = self._load_f5tts_model(str(cached_path(custom_path)))
self.current_model_path = custom_path
return f"โหลดโมเดลแบบกำหนดเอง: {custom_path}"
elif model_choice == "FP16":
self.f5tts_model = self._load_f5tts_model(str(cached_path(FP16_MODEL_BASE)))
self.current_model_path = FP16_MODEL_BASE
return f"โหลดโมเดล FP16: {FP16_MODEL_BASE}"
else: # Default
self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE)))
self.current_model_path = DEFAULT_MODEL_BASE
return f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}"
except Exception as e:
error_msg = f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}"
print(error_msg)
return error_msg
def get_model(self):
"""ดึงโมเดล F5-TTS ปัจจุบัน"""
if self.f5tts_model is None:
self.load_default_model()
return self.f5tts_model
def get_vocoder(self):
"""ดึง vocoder"""
return self.vocoder
def get_current_model_info(self):
"""ดึงข้อมูลโมเดลปัจจุบัน"""
return {
"model_path": self.current_model_path,
"is_loaded": self.f5tts_model is not None
}
def update_custom_model_visibility(self, selected_model):
"""อัปเดตการแสดงผลของ custom model input"""
return selected_model == "Custom"