File size: 13,298 Bytes
35bb6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
from __future__ import annotations

import asyncio
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import AsyncGenerator

import numpy as np
from loguru import logger

from api.src.core.config import settings
from api.src.core.model_config import (
    BACKBONE_MODELS,
    BackendType,
    get_backbone_info,
)


class ModelLoadStatus(str, Enum):
    """Lifecycle states stored in ``ModelLoadingTask.status``.

    The ``str`` mixin makes every member an instance of ``str`` whose value
    is the lowercase name listed below.
    """

    # Task has been created and queued; the background load has not run yet.
    PENDING = "pending"
    # Set as soon as the background load starts.
    DOWNLOADING = "downloading"
    # Set by a timer once a load has been running for a few seconds.
    LOADING = "loading"
    # The model is registered with the manager (or was already loaded).
    READY = "ready"
    # The load raised; details are in ``ModelLoadingTask.error_message``.
    ERROR = "error"


@dataclass
class ModelLoadingTask:
    """Pollable record describing one request to load a model."""

    # Unique id of this task (the manager generates a UUID4 string).
    task_id: str
    # Id of the model this task is loading.
    model_id: str
    # Current lifecycle state; updated in place while the load progresses.
    status: ModelLoadStatus = ModelLoadStatus.PENDING
    # Short human-readable description of the current state.
    progress_message: str = ""
    # Text of the exception when the load failed; empty otherwise.
    error_message: str = ""
    # ``time.time()`` timestamp taken when the task was created.
    started_at: float = 0.0
    # ``time.time()`` timestamp taken when the task finished; 0.0 while unfinished.
    completed_at: float = 0.0


@dataclass
class LoadedModel:
    """A constructed TTS model together with the settings it was loaded with."""

    # Id of the backbone model; also the key under which the manager stores it.
    model_id: str
    # Codec repository the instance was created with.
    codec_id: str
    tts_instance: object  # NeuTTS instance
    # Held by the manager around every use of ``tts_instance`` and while unloading.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Device names the backbone and the codec were loaded on.
    backbone_device: str = "cpu"
    codec_device: str = "cpu"


class ModelManager:
    _instance: ModelManager | None = None

    def __init__(self) -> None:
        """Create an empty manager with its own worker thread pool."""
        # Pool used for model construction and for every blocking model call;
        # its size comes from the application settings.
        worker_count = settings.max_inference_workers
        self._executor = ThreadPoolExecutor(max_workers=worker_count)
        # Load tasks keyed by task id, loaded models keyed by model id.
        self._loading_tasks: dict[str, ModelLoadingTask] = {}
        self._models: dict[str, LoadedModel] = {}

    @classmethod
    def get_instance(cls) -> ModelManager:
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @property
    def loaded_models(self) -> dict[str, LoadedModel]:
        """Loaded models keyed by model id.

        This is the manager's internal dict itself, not a copy, so changes
        made through it are visible to the manager.
        """
        return self._models

    @property
    def loading_tasks(self) -> dict[str, ModelLoadingTask]:
        """Known load tasks keyed by task id.

        This is the manager's internal dict itself, not a copy, so changes
        made through it are visible to the manager.
        """
        return self._loading_tasks

    def is_loaded(self, model_id: str) -> bool:
        """Return True when ``model_id`` has an entry in the loaded-model registry.

        A model whose load is still in progress is not registered yet and
        therefore reports False.
        """
        return model_id in self._models

    def get_task(self, task_id: str) -> ModelLoadingTask | None:
        """Return the load task registered under ``task_id``, or None if there is none."""
        return self._loading_tasks.get(task_id)

    async def load_model_async(
        self,
        model_id: str,
        codec_id: str | None = None,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> ModelLoadingTask:
        """Start loading a model in the background. Returns a task for polling.

        The call returns immediately in one of three ways:

        * the model is already loaded: a new task that is already READY is
          registered and returned;
        * a load of the same model is still in progress: that existing task
          is returned;
        * otherwise a new PENDING task is registered and a background load is
          scheduled on the running event loop.

        In the first two cases ``codec_id`` and the device arguments are
        ignored.

        Args:
            model_id: Id of the backbone model to load.
            codec_id: Codec repository, forwarded to the background load.
            backbone_device: Backbone device, forwarded to the background load.
            codec_device: Codec device, forwarded to the background load.

        Returns:
            The task whose ``status`` can be polled via ``get_task``.

        Raises:
            ValueError: If ``model_id`` is not a known backbone model.
        """
        # Already loaded -> return READY task immediately
        if model_id in self._models:
            now = time.time()
            task = ModelLoadingTask(
                task_id=str(uuid.uuid4()),
                model_id=model_id,
                status=ModelLoadStatus.READY,
                progress_message="Already loaded",
                started_at=now,
                completed_at=now,
            )
            self._loading_tasks[task.task_id] = task
            return task

        # Already loading -> return existing task
        in_flight = (
            ModelLoadStatus.PENDING,
            ModelLoadStatus.DOWNLOADING,
            ModelLoadStatus.LOADING,
        )
        for task in self._loading_tasks.values():
            if task.model_id == model_id and task.status in in_flight:
                return task

        if get_backbone_info(model_id) is None:
            raise ValueError(f"Unknown model: {model_id}. Available: {list(BACKBONE_MODELS.keys())}")

        task = ModelLoadingTask(
            task_id=str(uuid.uuid4()),
            model_id=model_id,
            status=ModelLoadStatus.PENDING,
            progress_message="Queued",
            started_at=time.time(),
        )
        self._loading_tasks[task.task_id] = task

        background = asyncio.ensure_future(
            self._background_load(task, codec_id, backbone_device, codec_device)
        )
        # The event loop only holds weak references to tasks, so an otherwise
        # unreferenced background load could be garbage-collected before it
        # finishes. Keep a strong reference until the load is done.
        pending_loads = getattr(self, "_background_loads", None)
        if pending_loads is None:
            pending_loads = set()
            self._background_loads = pending_loads
        pending_loads.add(background)
        background.add_done_callback(pending_loads.discard)
        return task

    async def _background_load(
        self,
        task: ModelLoadingTask,
        codec_id: str | None,
        backbone_device: str | None,
        codec_device: str | None,
    ) -> None:
        """Background coroutine that loads a model and updates task status.

        The model is constructed in the manager's thread pool. On success it
        is registered under ``task.model_id`` and the task becomes READY. Any
        ordinary exception is recorded on the task (status ERROR) and logged
        instead of being raised. If the coroutine is cancelled, the task is
        marked ERROR as well and the cancellation is re-raised.

        Args:
            task: Task record to update while the load progresses.
            codec_id: Codec repository; falls back to ``settings.default_codec``.
            backbone_device: Backbone device; falls back to
                ``settings.resolved_backbone_device``. Forced to ``"cpu"``
                for GGUF backbones.
            codec_device: Codec device; falls back to
                ``settings.default_codec_device``.
        """
        timer_task = None
        try:
            task.status = ModelLoadStatus.DOWNLOADING
            task.progress_message = "Downloading / checking cache..."

            info = get_backbone_info(task.model_id)
            if info is None:
                raise ValueError(f"Unknown model: {task.model_id}")

            codec = codec_id or settings.default_codec
            bb_device = backbone_device or settings.resolved_backbone_device
            cc_device = codec_device or settings.default_codec_device

            # GGUF models only support CPU (llama.cpp limitation)
            if info.backend == BackendType.GGUF:
                bb_device = "cpu"

            logger.info(
                f"[Task {task.task_id[:8]}] Loading {task.model_id} "
                f"(backbone_device={bb_device}, codec_device={cc_device})"
            )

            # Schedule status transition after 3s (heuristic for download vs load)
            async def _mark_loading() -> None:
                await asyncio.sleep(3)
                if task.status == ModelLoadStatus.DOWNLOADING:
                    task.status = ModelLoadStatus.LOADING
                    task.progress_message = "Initializing model..."

            timer_task = asyncio.ensure_future(_mark_loading())

            loop = asyncio.get_running_loop()
            tts = await loop.run_in_executor(
                self._executor,
                self._create_tts_instance,
                info.repo,
                codec,
                bb_device,
                cc_device,
            )

            self._models[task.model_id] = LoadedModel(
                model_id=task.model_id,
                codec_id=codec,
                tts_instance=tts,
                backbone_device=bb_device,
                codec_device=cc_device,
            )

            task.status = ModelLoadStatus.READY
            task.progress_message = "Model ready"
            task.completed_at = time.time()
            logger.info(f"[Task {task.task_id[:8]}] {task.model_id} loaded successfully")

        except asyncio.CancelledError:
            # Without this the task would stay in an in-progress state forever
            # and later load requests for the same model would keep getting it.
            task.status = ModelLoadStatus.ERROR
            task.error_message = "Load cancelled"
            task.progress_message = "Failed"
            task.completed_at = time.time()
            logger.warning(f"[Task {task.task_id[:8]}] Loading {task.model_id} was cancelled")
            raise
        except Exception as e:
            task.status = ModelLoadStatus.ERROR
            task.error_message = str(e)
            task.progress_message = "Failed"
            task.completed_at = time.time()
            logger.error(f"[Task {task.task_id[:8]}] Failed to load {task.model_id}: {e}")
        finally:
            # Stop the status timer on every exit path, not only on success.
            if timer_task is not None:
                timer_task.cancel()

    async def load_model(
        self,
        model_id: str,
        codec_id: str | None = None,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> LoadedModel:
        """Load a model and wait until it is ready. Used by startup.

        Unlike ``load_model_async`` this coroutine does not create a task
        record: it awaits the construction (which runs in the manager's
        thread pool) and lets any failure propagate to the caller.

        Args:
            model_id: Id of the backbone model to load.
            codec_id: Codec repository; falls back to ``settings.default_codec``.
            backbone_device: Backbone device; falls back to
                ``settings.resolved_backbone_device``. Forced to ``"cpu"``
                for GGUF backbones.
            codec_device: Codec device; falls back to
                ``settings.default_codec_device``.

        Returns:
            The registered ``LoadedModel``; the existing entry if the model
            was already loaded.

        Raises:
            ValueError: If ``model_id`` is not a known backbone model.
        """
        if model_id in self._models:
            logger.info(f"Model {model_id} already loaded")
            return self._models[model_id]

        info = get_backbone_info(model_id)
        if info is None:
            raise ValueError(f"Unknown model: {model_id}. Available: {list(BACKBONE_MODELS.keys())}")

        codec = codec_id or settings.default_codec
        bb_device = backbone_device or settings.resolved_backbone_device
        cc_device = codec_device or settings.default_codec_device

        # GGUF models only support CPU (llama.cpp limitation)
        if info.backend == BackendType.GGUF:
            bb_device = "cpu"

        logger.info(
            f"Loading model {model_id} (repo={info.repo}, codec={codec}, "
            f"backbone_device={bb_device}, codec_device={cc_device})"
        )

        # Construction blocks, so run it in the thread pool instead of on the loop.
        loop = asyncio.get_running_loop()
        tts = await loop.run_in_executor(
            self._executor,
            self._create_tts_instance,
            info.repo,
            codec,
            bb_device,
            cc_device,
        )

        loaded = LoadedModel(
            model_id=model_id,
            codec_id=codec,
            tts_instance=tts,
            backbone_device=bb_device,
            codec_device=cc_device,
        )
        self._models[model_id] = loaded
        logger.info(f"Model {model_id} loaded successfully")
        return loaded

    @staticmethod
    def _create_tts_instance(
        backbone_repo: str,
        codec_repo: str,
        backbone_device: str,
        codec_device: str,
    ) -> object:
        """Construct a ``NeuTTS`` instance from the given repositories and devices.

        The manager calls this through its thread pool so the construction
        does not run on the event loop.
        """
        # Imported at call time rather than when this module is imported.
        from neutts import NeuTTS

        init_kwargs = {
            "backbone_repo": backbone_repo,
            "backbone_device": backbone_device,
            "codec_repo": codec_repo,
            "codec_device": codec_device,
        }
        return NeuTTS(**init_kwargs)

    async def unload_model(self, model_id: str) -> None:
        if model_id not in self._models:
            raise ValueError(f"Model {model_id} is not loaded")

        loaded = self._models.pop(model_id)
        async with loaded.lock:
            del loaded.tts_instance
        logger.info(f"Model {model_id} unloaded")

    async def switch_device(
        self,
        model_id: str,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> ModelLoadingTask:
        """Unload model and reload on a different device."""
        if model_id not in self._models:
            raise ValueError(f"Model {model_id} is not loaded")

        loaded = self._models[model_id]
        info = get_backbone_info(model_id)

        if info and info.backend == BackendType.GGUF:
            raise ValueError(
                f"Model {model_id} is GGUF (llama.cpp) and only supports CPU. "
                "Device switching is not available for GGUF models."
            )

        codec_id = loaded.codec_id
        bb_device = backbone_device or loaded.backbone_device
        cc_device = codec_device or loaded.codec_device

        logger.info(f"Switching {model_id} device to backbone={bb_device}, codec={cc_device}")
        await self.unload_model(model_id)

        return await self.load_model_async(
            model_id=model_id,
            codec_id=codec_id,
            backbone_device=bb_device,
            codec_device=cc_device,
        )

    def cleanup_old_tasks(self, max_age_seconds: float = 3600) -> int:
        """Remove completed/errored tasks older than max_age_seconds."""
        now = time.time()
        to_remove = [
            tid
            for tid, t in self._loading_tasks.items()
            if t.status in (ModelLoadStatus.READY, ModelLoadStatus.ERROR)
            and t.completed_at > 0
            and (now - t.completed_at) > max_age_seconds
        ]
        for tid in to_remove:
            del self._loading_tasks[tid]
        return len(to_remove)

    async def infer(
        self,
        model_id: str,
        text: str,
        ref_codes: object,
        ref_text: str,
    ) -> np.ndarray:
        loaded = self._get_loaded(model_id)

        async with loaded.lock:
            loop = asyncio.get_event_loop()
            wav = await loop.run_in_executor(
                self._executor,
                loaded.tts_instance.infer,
                text,
                ref_codes,
                ref_text,
            )
        return wav

    async def infer_stream(
        self,
        model_id: str,
        text: str,
        ref_codes: object,
        ref_text: str,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Yield chunks produced by the model's ``infer_stream``.

        The model's blocking generator runs in the thread pool and its chunks
        are handed to this coroutine through a queue. The model's lock is
        held until the worker thread has finished, including when the
        consumer stops iterating early; in that case the worker is asked to
        stop and exits once the model yields its next chunk.

        An exception raised by the model while streaming is logged and ends
        the stream; it is not raised to the consumer.

        Args:
            model_id: Id of a loaded model whose backbone supports streaming.
            text: Passed to the model's ``infer_stream`` as its first argument.
            ref_codes: Passed through to the model's ``infer_stream`` unchanged.
            ref_text: Passed through to the model's ``infer_stream`` unchanged.

        Raises:
            ValueError: If the model is not loaded or does not support
                streaming (raised when iteration starts).
        """
        import threading

        loaded = self._get_loaded(model_id)
        info = get_backbone_info(model_id)

        if info is None or not info.supports_streaming:
            raise ValueError(
                f"Model {model_id} does not support streaming. "
                "Only GGUF models support infer_stream()."
            )

        queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
        loop = asyncio.get_running_loop()
        stop_requested = threading.Event()

        def _publish(item: object) -> None:
            # asyncio.Queue is not thread-safe: schedule the put on the event
            # loop so a consumer waiting in queue.get() is actually woken up.
            loop.call_soon_threadsafe(queue.put_nowait, item)

        def _stream_worker() -> None:
            try:
                for chunk in loaded.tts_instance.infer_stream(text, ref_codes, ref_text):
                    if stop_requested.is_set():
                        break
                    _publish(chunk)
            except Exception as e:
                logger.error(f"Streaming error for {model_id}: {e}")
            finally:
                # None is the end-of-stream marker for the consumer loop below.
                _publish(None)

        async with loaded.lock:
            worker = loop.run_in_executor(self._executor, _stream_worker)
            try:
                while True:
                    chunk = await queue.get()
                    if chunk is None:
                        break
                    yield chunk
            finally:
                # Do not release the lock while the worker thread may still be
                # using the model (consumer closed early or was cancelled).
                stop_requested.set()
                await worker

    async def encode_reference(self, model_id: str, audio_path: str) -> object:
        loaded = self._get_loaded(model_id)

        async with loaded.lock:
            loop = asyncio.get_event_loop()
            ref_codes = await loop.run_in_executor(
                self._executor,
                loaded.tts_instance.encode_reference,
                audio_path,
            )
        return ref_codes

    def _get_loaded(self, model_id: str) -> LoadedModel:
        loaded = self._models.get(model_id)
        if loaded is None:
            raise ValueError(
                f"Model {model_id} is not loaded. "
                f"Loaded models: {list(self._models.keys())}"
            )
        return loaded

    async def startup(self) -> None:
        """Load every model listed in ``settings.default_models_list``.

        Models are loaded one after another. A model that fails to load is
        logged and skipped so the remaining ones are still attempted.
        """
        default_ids = settings.default_models_list
        for default_id in default_ids:
            try:
                await self.load_model(default_id)
            except Exception as exc:
                logger.error(f"Failed to load default model {default_id}: {exc}")

    async def shutdown(self) -> None:
        model_ids = list(self._models.keys())
        for model_id in model_ids:
            try:
                await self.unload_model(model_id)
            except Exception:
                pass
        self._executor.shutdown(wait=False)