File size: 7,374 Bytes
bca11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
HuggingFace-compatible model loader for Romanian Matcha-TTS
"""

import json
import os
import torch
from pathlib import Path
from typing import Optional, Dict, Any

try:
    from huggingface_hub import hf_hub_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False


class ModelLoader:
    """
    HuggingFace-compatible loader for Romanian Matcha-TTS models.

    Resolves model/vocoder checkpoint paths from a local repository clone or
    from the HuggingFace Hub, driven by the repo's ``configs/config.json``.

    Usage:
        loader = ModelLoader.from_pretrained("adrianstanea/Ro-Matcha-TTS")
        paths = loader.load_models(model="bas_950")
    """

    def __init__(self, repo_path: str):
        """
        Initialize with local repository path or HuggingFace repo ID.

        Args:
            repo_path: Path to local repo or HuggingFace repo ID.

        Raises:
            FileNotFoundError: If ``configs/config.json`` is missing under
                ``repo_path``.
        """
        self.repo_path = repo_path
        self.config = self._load_config()

    @classmethod
    def from_pretrained(cls, repo_id: str, cache_dir: Optional[str] = None) -> "ModelLoader":
        """
        Load from HuggingFace Hub or local path.

        Args:
            repo_id: HuggingFace repo ID (e.g., "adrianstanea/Ro-Matcha-TTS")
                or local path.
            cache_dir: Optional cache directory for downloads.

        Returns:
            ModelLoader instance.

        Raises:
            ImportError: If ``huggingface_hub`` is not installed and
                ``repo_id`` is not a local path.
            ValueError: If the config could not be fetched from the Hub.
        """
        if os.path.exists(repo_id):
            # Local path: use the checkout directly.
            return cls(repo_id)

        if not HF_AVAILABLE:
            raise ImportError(
                "huggingface_hub is required for downloading from HF Hub. "
                "Install with: pip install huggingface_hub"
            )

        try:
            config_path = hf_hub_download(
                repo_id=repo_id,
                filename="configs/config.json",
                cache_dir=cache_dir,
            )
        except Exception as e:
            raise ValueError(f"Could not download from HuggingFace Hub: {e}") from e

        # config.json lives at <snapshot>/configs/config.json, so the snapshot
        # root is two levels above the downloaded file.
        repo_cache_path = Path(config_path).parent.parent
        return cls(str(repo_cache_path))

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the repo's ``configs/config.json``.

        Raises:
            FileNotFoundError: If the config file does not exist.
        """
        config_path = os.path.join(self.repo_path, "configs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")

        # Explicit encoding: JSON is UTF-8 by spec; don't rely on the locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _resolve_checkpoint(self, rel_file: str, label: str) -> str:
        """Resolve a checkpoint file locally, falling back to a Hub download.

        Args:
            rel_file: Checkpoint path relative to the repo root, as listed in
                the config's ``available_models`` entry.
            label: Human-readable artifact name used in error messages
                (e.g. "Model" or "Vocoder").

        Returns:
            Path to an existing checkpoint file.

        Raises:
            FileNotFoundError: If the file exists neither locally nor on the
                Hub (or the Hub is unreachable / unavailable).
        """
        path = os.path.join(self.repo_path, rel_file)
        if os.path.exists(path):
            return path

        # A repo_path that is not a local directory is treated as a Hub repo ID.
        if HF_AVAILABLE and not os.path.exists(self.repo_path):
            try:
                return hf_hub_download(
                    repo_id=self.repo_path,
                    filename=rel_file
                )
            except Exception as e:
                raise FileNotFoundError(
                    f"{label} file not found locally and could not download: {e}"
                ) from e

        raise FileNotFoundError(f"{label} file not found: {path}")

    def get_model_path(self, model: Optional[str] = None) -> str:
        """
        Get path to model checkpoint for specified model.

        Args:
            model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950).
                If None, uses the config's ``default_model``.

        Returns:
            Absolute path to model checkpoint.

        Raises:
            ValueError: If ``model`` is not listed in ``available_models``.
            FileNotFoundError: If the checkpoint cannot be resolved.
        """
        if model is None:
            model = self.config["default_model"]

        if model not in self.config["available_models"]:
            available = list(self.config["available_models"].keys())
            raise ValueError(f"Model '{model}' not available. Available: {available}")

        model_file = self.config["available_models"][model]["file"]
        return self._resolve_checkpoint(model_file, "Model")

    def get_vocoder_path(self) -> str:
        """
        Get path to vocoder checkpoint.

        Returns:
            Absolute path to vocoder checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint cannot be resolved.
        """
        vocoder_file = self.config["available_models"]["vocoder"]["file"]
        return self._resolve_checkpoint(vocoder_file, "Vocoder")

    def load_models(self, model: Optional[str] = None, device: str = "auto"):
        """
        Load TTS model and vocoder for inference.

        NOTE: This returns paths for use with the original Matcha-TTS repository.
        You'll need to import and use the original loading functions.

        Args:
            model: Model to load (swara, bas_10, bas_950, sgs_10, sgs_950).
                If None, uses the config's ``default_model``.
            device: Device to load on ("auto", "cpu", "cuda"). "auto" picks
                "cuda" when available, else "cpu".

        Returns:
            Dict with model and vocoder paths and configurations.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Resolve the effective model name once; get_model_path would apply
        # the same default, so pass the resolved name through.
        model_name = model if model is not None else self.config["default_model"]
        model_path = self.get_model_path(model_name)
        vocoder_path = self.get_vocoder_path()
        model_info = self.config["available_models"][model_name]

        return {
            "model_path": model_path,
            "vocoder_path": vocoder_path,
            "config": self.config,
            "model_name": model_name,
            "model_info": model_info,
            "device": device,
            "inference_params": self.config["inference_defaults"]
        }

    def list_models(self) -> Dict[str, Any]:
        """List available models (excluding the vocoder) with details."""
        models = {}
        for name, info in self.config["available_models"].items():
            if name != "vocoder":
                models[name] = {
                    "type": info["type"],
                    "description": info["description"],
                    # Single-speaker entries carry "speaker"; absent means
                    # the checkpoint is multi-speaker.
                    "speaker": info.get("speaker", "multi_speaker"),
                    "training_data": info.get("training_data", "N/A")
                }
        return models

    def list_research_variants(self):
        """List research comparison variants from the config."""
        return self.config["research_variants"]

    def get_model_info(self, model: Optional[str] = None):
        """Get detailed information about a specific model.

        Args:
            model: Model name; None selects the config's ``default_model``.

        Raises:
            ValueError: If the model is not listed in ``available_models``.
        """
        model_name = model or self.config["default_model"]
        if model_name not in self.config["available_models"]:
            raise ValueError(f"Model '{model_name}' not available")

        return self.config["available_models"][model_name]

    def get_sample_texts(self) -> list:
        """Get Romanian sample texts for testing."""
        return [
            "Bună ziua! Acesta este un test de sinteză vocală în limba română.",
            "Matcha-TTS funcționează foarte bine pentru limba română.",
            "Sistemul de sinteză vocală poate genera vorbire naturală.",
            "Această tehnologie folosește inteligența artificială avansată.",
            "Vorbirea sintetizată sună foarte realistă și naturală."
        ]