File size: 4,810 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Tuple

import torch


@dataclass(frozen=True)
class SamplingConfig:
    """Token-sampling hyperparameters for one decoding stream (text or audio)."""

    # Softmax temperature; must be strictly positive (enforced by
    # validate_generation_params, not by this dataclass).
    temperature: float = 0.8
    # Number of highest-probability tokens retained during top-k filtering;
    # must be strictly positive (same enforcement caveat as above).
    top_k: int = 50


def _default_text_sampling() -> SamplingConfig:
    """Return the default sampling settings for the text stream."""
    defaults = SamplingConfig(top_k=50, temperature=0.6)
    return defaults


def _default_audio_sampling() -> SamplingConfig:
    """Return the default sampling settings for the audio stream."""
    defaults = SamplingConfig(top_k=50, temperature=0.8)
    return defaults


@dataclass(frozen=True)
class PrefixConfig:
    """Per-speaker prefix settings used to seed generation.

    NOTE(review): the exact meaning of the speaker strings (file path vs.
    inline text vs. identifier) is not established in this module — confirm
    against the consumer of GenerationConfig.prefix.
    """

    # Prefix material for speaker 1; None disables it.
    speaker_1: Optional[str] = None
    # Prefix material for speaker 2; None disables it.
    speaker_2: Optional[str] = None
    # Whether prefix audio participates in generation (presumably whether the
    # prefix's audio tokens are fed to the decoder — verify against caller).
    include_audio: bool = False


@dataclass(frozen=True)
class GenerationConfig:
    """Aggregated, immutable knobs controlling one generation run."""

    # Sampling settings for the text token stream.
    text: SamplingConfig = field(default_factory=_default_text_sampling)
    # Sampling settings for the audio token stream.
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    # Classifier-free-guidance scale; must be strictly positive
    # (enforced by validate_generation_params).
    cfg_scale: float = 2.0
    # Top-k applied during CFG filtering — distinct from the per-stream top_k.
    cfg_filter_k: int = 50
    # NOTE(review): presumably the number of padding steps prepended before
    # decoding starts — confirm against the generation loop.
    initial_padding: int = 2
    # Optional speaker-prefix configuration; None means no prefix.
    prefix: Optional["PrefixConfig"] = None
    # Opt-in CUDA-graph capture (presumably a performance toggle — verify).
    use_cuda_graph: bool = False


@dataclass(frozen=True)
class GenerationResult:
    """Bundle of outputs produced by a generation run."""

    # Generated audio codec tokens (shape/dtype not established here —
    # TODO confirm against the producer).
    audio_tokens: torch.Tensor
    # Decoded audio waveform (layout not established here — TODO confirm).
    waveform: torch.Tensor
    # Sample rate of `waveform`, in Hz.
    sample_rate: int
    # (label, time) pairs — presumably per-segment/speaker timing marks;
    # verify against the code that fills this in.
    timestamps: List[Tuple[str, float]]


def normalize_script(script: str | Sequence[str]) -> str:
    """Collapse a script into one stripped string.

    A plain string is returned stripped of surrounding whitespace; a
    sequence of lines has each line stripped individually, with the
    results joined by newlines.
    """
    if not isinstance(script, str):
        stripped_lines = [line.strip() for line in script]
        return "\n".join(stripped_lines)
    return script.strip()


def load_script_text(path: str | Path) -> str:
    """Resolve *path* to script text.

    Resolution order:
      1. ``"-"`` (as str **or** Path) reads the script from stdin;
      2. an existing file is read as UTF-8 text;
      3. anything else is treated as inline script text and returned as-is.

    All results are stripped of surrounding whitespace.
    """
    # str(path) so Path("-") is honored too — the signature accepts Path,
    # but the original `path == "-"` only matched the plain string.
    if str(path) == "-":
        return sys.stdin.read().strip()
    path_obj = Path(path)
    if path_obj.exists():
        # Explicit encoding: read_text() otherwise uses the locale default,
        # which breaks on non-UTF-8 systems.
        return path_obj.read_text(encoding="utf-8").strip()
    return str(path).strip()


def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    """Check that all sampling knobs are strictly positive.

    Returns the validated ``(temperature, top_k, cfg_scale)`` triple
    unchanged; raises ``ValueError`` naming the first offending parameter.
    """
    checks = (
        ("temperature", temperature),
        ("top_k", top_k),
        ("cfg_scale", cfg_scale),
    )
    for param_name, value in checks:
        if value <= 0:
            # Message format matches the original per-parameter errors exactly.
            raise ValueError(f"{param_name} must be positive")
    return temperature, top_k, cfg_scale


def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    """Build a GenerationConfig applying one sampling setup to both streams.

    The same SamplingConfig instance is shared between text and audio;
    all other fields keep their GenerationConfig defaults.
    """
    shared_sampling = SamplingConfig(top_k=top_k, temperature=temperature)
    return GenerationConfig(audio=shared_sampling, cfg_scale=cfg_scale, text=shared_sampling)


def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    """Return a new GenerationConfig: *base* with non-None *overrides* applied.

    Recognized override keys: ``temp_text``, ``topk_text``, ``temp_audio``,
    ``topk_audio``, ``prefix_speaker_1``, ``prefix_speaker_2``,
    ``include_prefix``, ``cfg_scale``, ``cfg_filter_k``, ``initial_padding``,
    ``use_cuda_graph``. A value of None means "keep the base value";
    unrecognized keys are silently ignored.
    """
    # Drop explicit Nones up front so "key absent" and "key set to None"
    # behave identically everywhere below.
    clean_overrides = {k: v for k, v in overrides.items() if v is not None}
    text_temp = clean_overrides.pop("temp_text", None)
    text_topk = clean_overrides.pop("topk_text", None)
    audio_temp = clean_overrides.pop("temp_audio", None)
    audio_topk = clean_overrides.pop("topk_audio", None)
    prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
    prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
    include_prefix = clean_overrides.pop("include_prefix", None)

    # Rebuild the text SamplingConfig only if at least one of its two
    # fields was overridden (SamplingConfig is frozen, so no in-place edit).
    text_sampling = base.text
    if text_temp is not None or text_topk is not None:
        text_sampling = SamplingConfig(
            temperature=text_temp if text_temp is not None else text_sampling.temperature,
            top_k=text_topk if text_topk is not None else text_sampling.top_k,
        )

    # Same pattern for the audio stream.
    audio_sampling = base.audio
    if audio_temp is not None or audio_topk is not None:
        audio_sampling = SamplingConfig(
            temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
            top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
        )

    # Prefix handling: enter this branch when any prefix override was given
    # OR when the base already carries a prefix. In the latter case, with no
    # overrides, this rebuilds an equal (but distinct) PrefixConfig — the
    # result compares equal, so behavior is unaffected.
    prefix_cfg = base.prefix
    if (
        prefix_speaker_1 is not None
        or prefix_speaker_2 is not None
        or include_prefix is not None
        or prefix_cfg is not None
    ):
        # Fall back to an all-defaults PrefixConfig when the base had none,
        # so the field-by-field merge below always has something to read.
        prefix_cfg = prefix_cfg or PrefixConfig()
        prefix_cfg = PrefixConfig(
            speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
            speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
            include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
        )

    # Remaining scalar fields: pop with the base value as default, so an
    # absent key falls through to base. Anything still left in
    # clean_overrides after this point is silently discarded.
    return GenerationConfig(
        text=text_sampling,
        audio=audio_sampling,
        cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
        cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
        initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
        prefix=prefix_cfg,
        use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
    )


# Explicit public API of this module (names exported via `from ... import *`).
__all__ = [
    "SamplingConfig",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "normalize_script",
    "load_script_text",
    "validate_generation_params",
    "build_generation_config",
    "merge_generation_config",
]