File size: 3,940 Bytes
106478e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Model Manager for F5-TTS Thai
จัดการการโหลดและเปลี่ยนโมเดล F5-TTS
"""

import os
import torch
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model, load_vocoder
from f5_tts.model import DiT
from f5_tts.config import (
    DEFAULT_MODEL_BASE, 
    FP16_MODEL_BASE, 
    VOCAB_BASE, 
    VOCAB_HF, 
    F5TTS_MODEL_CFG,
    MODEL_CHOICES
)


class ModelManager:
    """จัดการการโหลดและเปลี่ยนโมเดล F5-TTS"""
    
    def __init__(self):
        self.f5tts_model = None
        self.vocoder = None
        self.current_model_path = None
        self._initialize()
    
    def _initialize(self):
        """เริ่มต้นโหลดโมเดลเริ่มต้น"""
        self.vocoder = load_vocoder()
        self.load_default_model()
    
    def load_default_model(self):
        """โหลดโมเดลเริ่มต้น"""
        self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE)))
        self.current_model_path = DEFAULT_MODEL_BASE
        print(f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}")
    
    def _load_f5tts_model(self, ckpt_path, vocab_path=VOCAB_BASE):
        """โหลดโมเดล F5-TTS"""
        vocab_file = vocab_path if os.path.exists(VOCAB_BASE) else str(cached_path(VOCAB_HF))
        model = load_model(
            DiT, 
            F5TTS_MODEL_CFG, 
            ckpt_path, 
            vocab_file=vocab_file, 
            use_ema=True
        )
        print(f"โหลดโมเดลจาก {ckpt_path}")
        return model
    
    def load_model_by_choice(self, model_choice, custom_path=None):
        """โหลดโมเดลตามตัวเลือก"""
        torch.cuda.empty_cache()
        
        try:
            if model_choice == "Custom":
                if not custom_path:
                    raise ValueError("กรุณาระบุตำแหน่งโมเดลแบบกำหนดเอง")
                self.f5tts_model = self._load_f5tts_model(str(cached_path(custom_path)))
                self.current_model_path = custom_path
                return f"โหลดโมเดลแบบกำหนดเอง: {custom_path}"
            
            elif model_choice == "FP16":
                self.f5tts_model = self._load_f5tts_model(str(cached_path(FP16_MODEL_BASE)))
                self.current_model_path = FP16_MODEL_BASE
                return f"โหลดโมเดล FP16: {FP16_MODEL_BASE}"
            
            else:  # Default
                self.f5tts_model = self._load_f5tts_model(str(cached_path(DEFAULT_MODEL_BASE)))
                self.current_model_path = DEFAULT_MODEL_BASE
                return f"โหลดโมเดลเริ่มต้น: {DEFAULT_MODEL_BASE}"
                
        except Exception as e:
            error_msg = f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}"
            print(error_msg)
            return error_msg
    
    def get_model(self):
        """ดึงโมเดล F5-TTS ปัจจุบัน"""
        if self.f5tts_model is None:
            self.load_default_model()
        return self.f5tts_model
    
    def get_vocoder(self):
        """ดึง vocoder"""
        return self.vocoder
    
    def get_current_model_info(self):
        """ดึงข้อมูลโมเดลปัจจุบัน"""
        return {
            "model_path": self.current_model_path,
            "is_loaded": self.f5tts_model is not None
        }
    
    def update_custom_model_visibility(self, selected_model):
        """อัปเดตการแสดงผลของ custom model input"""
        return selected_model == "Custom"