|
|
from __future__ import annotations |
|
|
|
|
|
import sys |
|
|
from dataclasses import dataclass, field |
|
|
from pathlib import Path |
|
|
from typing import List, Mapping, Optional, Sequence, Tuple |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class SamplingConfig:
    """Immutable pair of sampling settings: a temperature and a top-k cutoff.

    The dataclass is frozen, so attributes cannot be reassigned after
    construction and one instance can be shared between several owners.
    """

    # Sampling temperature; 0.8 unless overridden.
    temperature: float = 0.8
    # Top-k cutoff; 50 unless overridden.
    top_k: int = 50
|
|
|
|
|
|
|
|
def _default_text_sampling() -> SamplingConfig:
    """Return the default ``SamplingConfig`` used for ``GenerationConfig.text``.

    Returns:
        A config with ``temperature=0.6`` and ``top_k=50``.
    """
    text_defaults = SamplingConfig(temperature=0.6, top_k=50)
    return text_defaults
|
|
|
|
|
|
|
|
def _default_audio_sampling() -> SamplingConfig:
    """Return the default ``SamplingConfig`` used for ``GenerationConfig.audio``.

    Returns:
        A config with ``temperature=0.8`` and ``top_k=50``.
    """
    audio_defaults = SamplingConfig(temperature=0.8, top_k=50)
    return audio_defaults
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PrefixConfig:
    """Immutable description of an optional per-speaker prefix.

    All fields are optional; a default instance carries no speaker values and
    has ``include_audio`` switched off.
    """

    # Prefix value for the first speaker, or None when not set.
    # NOTE(review): presumably a path or identifier of a prompt clip - confirm.
    speaker_1: Optional[str] = None
    # Prefix value for the second speaker, or None when not set.
    speaker_2: Optional[str] = None
    # Flag controlling whether the prefix audio is included; off by default.
    include_audio: bool = False
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class GenerationConfig:
    """Immutable top-level settings for one generation run.

    The two sampling fields use default factories, so each instance created
    without explicit values gets its own freshly built ``SamplingConfig``.
    """

    # Sampling settings for the text side (factory default: temperature 0.6, top_k 50).
    text: SamplingConfig = field(default_factory=_default_text_sampling)
    # Sampling settings for the audio side (factory default: temperature 0.8, top_k 50).
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    # NOTE(review): presumably the classifier-free-guidance scale - confirm with the consumer.
    cfg_scale: float = 2.0
    # NOTE(review): presumably a top-k filter applied alongside CFG - confirm with the consumer.
    cfg_filter_k: int = 50
    # Amount of initial padding; the unit is not visible in this module.
    initial_padding: int = 2
    # Optional prefix settings; None means no prefix is configured.
    prefix: Optional["PrefixConfig"] = None
    # Opt-in flag for CUDA graph usage; disabled by default.
    use_cuda_graph: bool = False
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class GenerationResult:
    """Immutable container for the outputs of one generation run.

    All four fields are required; there are no defaults.
    """

    # Generated audio tokens (tensor layout is not specified in this module).
    audio_tokens: torch.Tensor
    # Waveform tensor (layout is not specified in this module).
    waveform: torch.Tensor
    # Sample rate associated with ``waveform``.
    sample_rate: int
    # (label, time) pairs.
    # NOTE(review): presumably text labels with a time in seconds - confirm.
    timestamps: List[Tuple[str, float]]
|
|
|
|
|
|
|
|
def normalize_script(script: str | Sequence[str]) -> str:
    """Collapse a script given as one string or as a sequence of lines.

    Args:
        script: Either the whole script as a single string, or its lines.

    Returns:
        For a string, the string with surrounding whitespace stripped. For a
        sequence, each element stripped individually and joined with newlines
        (elements that are blank after stripping are kept as empty lines).
    """
    if not isinstance(script, str):
        trimmed_lines = [entry.strip() for entry in script]
        return "\n".join(trimmed_lines)
    return script.strip()
|
|
|
|
|
|
|
|
def load_script_text(path: str | Path) -> str:
    """Resolve a script argument to the script text itself.

    The argument is interpreted in this order:

    1. ``"-"`` reads the script from standard input.
    2. A path that exists on disk is read as text.
    3. Anything else is treated as the literal script text.

    Args:
        path: ``"-"``, a filesystem path, or inline script text.

    Returns:
        The resolved text with leading and trailing whitespace removed.

    Raises:
        OSError: If probing or reading the path fails for a reason other than
            the value being too long to be a file name.
    """
    import errno  # local import: only needed for the ENAMETOOLONG probe below

    if path == "-":
        return sys.stdin.read().strip()

    path_obj = Path(path)
    try:
        is_on_disk = path_obj.exists()
    except OSError as err:
        # Inline script text is often longer than the OS file-name limit, and
        # Path.exists() re-raises ENAMETOOLONG rather than returning False.
        # Such a value cannot name a file, so treat it as literal text. Other
        # OS errors (e.g. permission problems) still propagate.
        if err.errno != errno.ENAMETOOLONG:
            raise
        is_on_disk = False

    if is_on_disk:
        return path_obj.read_text().strip()
    return str(path).strip()
|
|
|
|
|
|
|
|
def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    """Check that all three generation parameters are strictly positive.

    Args:
        temperature: Sampling temperature.
        top_k: Top-k cutoff.
        cfg_scale: CFG scale.

    Returns:
        The inputs unchanged, as ``(temperature, top_k, cfg_scale)``.

    Raises:
        ValueError: For the first parameter (in the order above) that is
            less than or equal to zero.
    """
    # Checked in declaration order so the first offending parameter is reported.
    checks = (
        (temperature, "temperature must be positive"),
        (top_k, "top_k must be positive"),
        (cfg_scale, "cfg_scale must be positive"),
    )
    for value, message in checks:
        if value <= 0:
            raise ValueError(message)
    return temperature, top_k, cfg_scale
|
|
|
|
|
|
|
|
def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    """Create a ``GenerationConfig`` that samples text and audio identically.

    Args:
        temperature: Temperature applied to both the text and audio sampling.
        top_k: Top-k cutoff applied to both the text and audio sampling.
        cfg_scale: Value stored in ``GenerationConfig.cfg_scale``.

    Returns:
        A config whose ``text`` and ``audio`` fields are the same
        ``SamplingConfig`` instance; every other field keeps its default.
    """
    # One frozen instance is shared by both fields; no copy is needed.
    shared_sampling = SamplingConfig(top_k=top_k, temperature=temperature)
    config_kwargs = {
        "text": shared_sampling,
        "audio": shared_sampling,
        "cfg_scale": cfg_scale,
    }
    return GenerationConfig(**config_kwargs)
|
|
|
|
|
|
|
|
def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    """Return a new ``GenerationConfig`` made of ``base`` plus ``overrides``.

    Recognised override keys are ``temp_text``, ``topk_text``, ``temp_audio``,
    ``topk_audio``, ``prefix_speaker_1``, ``prefix_speaker_2``,
    ``include_prefix``, ``cfg_scale``, ``cfg_filter_k``, ``initial_padding``
    and ``use_cuda_graph``. Entries whose value is ``None`` count as "not
    supplied", and any other keys are ignored.

    Args:
        base: Configuration supplying every value that is not overridden.
        overrides: Flat mapping of override keys to their new values.

    Returns:
        A freshly constructed ``GenerationConfig``; ``base`` is not modified.
    """
    # Drop None-valued entries so "absent" and "explicitly None" behave alike.
    supplied = {key: value for key, value in overrides.items() if value is not None}

    def _merged_sampling(current: SamplingConfig, temp_key: str, topk_key: str) -> SamplingConfig:
        # Reuse the existing object untouched when neither knob is overridden.
        temp = supplied.get(temp_key)
        topk = supplied.get(topk_key)
        if temp is None and topk is None:
            return current
        return SamplingConfig(
            temperature=current.temperature if temp is None else temp,
            top_k=current.top_k if topk is None else topk,
        )

    text_sampling = _merged_sampling(base.text, "temp_text", "topk_text")
    audio_sampling = _merged_sampling(base.audio, "temp_audio", "topk_audio")

    speaker_1 = supplied.get("prefix_speaker_1")
    speaker_2 = supplied.get("prefix_speaker_2")
    include_audio = supplied.get("include_prefix")

    prefix_cfg = base.prefix
    no_prefix_overrides = speaker_1 is None and speaker_2 is None and include_audio is None
    # A prefix is built when any prefix override is given or base already has one.
    if not (no_prefix_overrides and prefix_cfg is None):
        seed = prefix_cfg or PrefixConfig()
        prefix_cfg = PrefixConfig(
            speaker_1=seed.speaker_1 if speaker_1 is None else speaker_1,
            speaker_2=seed.speaker_2 if speaker_2 is None else speaker_2,
            include_audio=seed.include_audio if include_audio is None else include_audio,
        )

    return GenerationConfig(
        text=text_sampling,
        audio=audio_sampling,
        cfg_scale=supplied.get("cfg_scale", base.cfg_scale),
        cfg_filter_k=supplied.get("cfg_filter_k", base.cfg_filter_k),
        initial_padding=supplied.get("initial_padding", base.initial_padding),
        prefix=prefix_cfg,
        use_cuda_graph=supplied.get("use_cuda_graph", base.use_cuda_graph),
    )
|
|
|
|
|
|
|
|
# Public API of this module: the names exported by a star-import.
# The private default factories (_default_*_sampling) are intentionally absent.
__all__ = [
    # Configuration and result dataclasses.
    "SamplingConfig",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    # Script-text helpers.
    "normalize_script",
    "load_script_text",
    # Parameter validation and config construction helpers.
    "validate_generation_params",
    "build_generation_config",
    "merge_generation_config",
]
|
|
|