Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,810 Bytes
1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Tuple
import torch
@dataclass(frozen=True)
class SamplingConfig:
    """Immutable pair of sampling settings: a temperature and a top-k value.

    The class only stores the two numbers; it performs no validation and
    leaves their interpretation to whatever sampler consumes the object.
    """

    # Sampling temperature (default 0.8).
    temperature: float = 0.8
    # Top-k cutoff (default 50).
    top_k: int = 50
def _default_text_sampling() -> SamplingConfig:
    """Return the default ``SamplingConfig`` for ``GenerationConfig.text``.

    Returns:
        A fresh config with temperature 0.6 and top-k 50.
    """
    text_defaults = SamplingConfig(temperature=0.6, top_k=50)
    return text_defaults
def _default_audio_sampling() -> SamplingConfig:
    """Return the default ``SamplingConfig`` for ``GenerationConfig.audio``.

    Returns:
        A fresh config with temperature 0.8 and top-k 50.
    """
    audio_defaults = SamplingConfig(temperature=0.8, top_k=50)
    return audio_defaults
@dataclass(frozen=True)
class PrefixConfig:
    """Immutable settings describing an optional per-speaker prefix.

    All fields are optional; a default-constructed instance carries no
    speaker entries and has ``include_audio`` switched off.
    """

    # Prefix entry for speaker 1, or None when not provided.
    # NOTE(review): presumably a path/identifier of a reference clip — confirm
    # against the code that consumes this config.
    speaker_1: Optional[str] = None
    # Prefix entry for speaker 2, or None when not provided.
    speaker_2: Optional[str] = None
    # Flag stored alongside the speaker entries; False by default.
    include_audio: bool = False
@dataclass(frozen=True)
class GenerationConfig:
    """Immutable bundle of every setting passed to a generation run.

    ``text`` and ``audio`` each get their own ``SamplingConfig``; the defaults
    come from the module's private factory functions so every instance
    receives a freshly built config.
    """

    # Sampling settings for the text side (factory default: 0.6 / 50).
    text: SamplingConfig = field(default_factory=_default_text_sampling)
    # Sampling settings for the audio side (factory default: 0.8 / 50).
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    # NOTE(review): presumably the classifier-free-guidance scale — confirm.
    cfg_scale: float = 2.0
    # NOTE(review): presumably a top-k filter applied together with CFG — confirm.
    cfg_filter_k: int = 50
    # Initial padding amount; units are not visible from this module.
    initial_padding: int = 2
    # Optional prefix settings; None means no prefix configuration.
    prefix: Optional["PrefixConfig"] = None
    # Opt-in flag for CUDA-graph execution; consumed outside this module.
    use_cuda_graph: bool = False
@dataclass(frozen=True)
class GenerationResult:
    """Immutable container for the outputs of one generation run.

    The class only stores what it is given; tensor shapes and dtypes are
    determined by the producer and are not checked here.
    """

    # Generated audio tokens.
    audio_tokens: torch.Tensor
    # Waveform tensor associated with the audio tokens.
    waveform: torch.Tensor
    # Sample rate of ``waveform``.
    sample_rate: int
    # (text, time) pairs.
    # NOTE(review): presumably word/segment labels with times in seconds — confirm.
    timestamps: List[Tuple[str, float]]
def normalize_script(script: str | Sequence[str]) -> str:
    """Collapse a script given as a string or a sequence of lines into one string.

    Args:
        script: Either the full script text, or an iterable of individual lines.

    Returns:
        For a string, the text with surrounding whitespace removed. For a
        sequence, every entry stripped individually and joined with newlines
        (the joined result itself is not stripped again).
    """
    if not isinstance(script, str):
        stripped_lines = [entry.strip() for entry in script]
        return "\n".join(stripped_lines)
    return script.strip()
def load_script_text(path: str | Path) -> str:
    """Resolve *path* to script text.

    The argument is interpreted in this order:

    1. ``"-"`` reads the whole of standard input.
    2. The path of an existing regular file reads that file (decoded as UTF-8).
    3. Anything else is treated as the script text itself.

    Args:
        path: ``"-"``, a filesystem path, or literal script text.

    Returns:
        The resolved text with leading and trailing whitespace removed.
    """
    if path == "-":
        return sys.stdin.read().strip()

    candidate = Path(path)
    try:
        # is_file() rather than exists(): a directory (e.g. "." — which is
        # also what Path("") becomes) must fall through to the literal-text
        # branch instead of crashing inside read_text().
        is_script_file = candidate.is_file()
    except (OSError, ValueError):
        # Inline script text is not necessarily a legal path: long text can
        # exceed the filesystem's name limit, and odd characters can be
        # rejected outright. Neither case names a file.
        is_script_file = False

    if is_script_file:
        # Explicit encoding so the result does not depend on the locale.
        return candidate.read_text(encoding="utf-8").strip()
    return str(path).strip()
def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    """Check that every generation parameter is strictly positive.

    Args:
        temperature: Sampling temperature.
        top_k: Top-k cutoff.
        cfg_scale: CFG scale value.

    Returns:
        The three inputs, unchanged, as ``(temperature, top_k, cfg_scale)``.

    Raises:
        ValueError: If any value is ``<= 0``. Values are checked in the order
            temperature, top_k, cfg_scale and the first offender is reported.
    """
    named_values = (
        ("temperature", temperature),
        ("top_k", top_k),
        ("cfg_scale", cfg_scale),
    )
    for label, value in named_values:
        if value <= 0:
            raise ValueError(f"{label} must be positive")
    return temperature, top_k, cfg_scale
def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    """Create a ``GenerationConfig`` that uses one sampling setup for text and audio.

    Args:
        temperature: Temperature applied to both the text and audio configs.
        top_k: Top-k value applied to both the text and audio configs.
        cfg_scale: Value stored in ``GenerationConfig.cfg_scale``.

    Returns:
        A config whose ``text`` and ``audio`` fields refer to the same
        ``SamplingConfig`` instance; every other field keeps its default.
    """
    shared_sampling = SamplingConfig(temperature=temperature, top_k=top_k)
    return GenerationConfig(text=shared_sampling, audio=shared_sampling, cfg_scale=cfg_scale)
def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    """Return a new ``GenerationConfig`` equal to *base* with *overrides* applied.

    Entries whose value is ``None`` are treated as "not supplied". Recognised
    keys are ``temp_text``, ``topk_text``, ``temp_audio``, ``topk_audio``,
    ``prefix_speaker_1``, ``prefix_speaker_2``, ``include_prefix``,
    ``cfg_scale``, ``cfg_filter_k``, ``initial_padding`` and
    ``use_cuda_graph``; any other key is ignored.

    Args:
        base: Configuration that provides every value not overridden.
        overrides: Mapping of override names to values.

    Returns:
        A freshly constructed ``GenerationConfig``. The ``text`` / ``audio``
        sub-configs are reused from *base* when none of their keys were
        supplied. ``prefix`` stays ``None`` only when *base* has no prefix and
        no prefix key was supplied.
    """
    supplied = {key: value for key, value in overrides.items() if value is not None}

    # Rebuild a sampling config only when at least one of its knobs changes.
    text_cfg = base.text
    if "temp_text" in supplied or "topk_text" in supplied:
        text_cfg = SamplingConfig(
            temperature=supplied.get("temp_text", text_cfg.temperature),
            top_k=supplied.get("topk_text", text_cfg.top_k),
        )

    audio_cfg = base.audio
    if "temp_audio" in supplied or "topk_audio" in supplied:
        audio_cfg = SamplingConfig(
            temperature=supplied.get("temp_audio", audio_cfg.temperature),
            top_k=supplied.get("topk_audio", audio_cfg.top_k),
        )

    prefix_keys = ("prefix_speaker_1", "prefix_speaker_2", "include_prefix")
    prefix_cfg = base.prefix
    if prefix_cfg is not None or any(key in supplied for key in prefix_keys):
        # Start from the existing prefix, or an empty one if base had none.
        current_prefix = prefix_cfg or PrefixConfig()
        prefix_cfg = PrefixConfig(
            speaker_1=supplied.get("prefix_speaker_1", current_prefix.speaker_1),
            speaker_2=supplied.get("prefix_speaker_2", current_prefix.speaker_2),
            include_audio=supplied.get("include_prefix", current_prefix.include_audio),
        )

    return GenerationConfig(
        text=text_cfg,
        audio=audio_cfg,
        cfg_scale=supplied.get("cfg_scale", base.cfg_scale),
        cfg_filter_k=supplied.get("cfg_filter_k", base.cfg_filter_k),
        initial_padding=supplied.get("initial_padding", base.initial_padding),
        prefix=prefix_cfg,
        use_cuda_graph=supplied.get("use_cuda_graph", base.use_cuda_graph),
    )
# Public API of this module: the names exported by ``from <module> import *``.
# The private default factories (``_default_*_sampling``) are intentionally
# not listed.
__all__ = [
    "SamplingConfig",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "normalize_script",
    "load_script_text",
    "validate_generation_params",
    "build_generation_config",
    "merge_generation_config",
]
|