trainer02 / miner.py
aiseosae's picture
Upload folder using huggingface_hub
618819a verified
Raw
History Blame Contribute Delete
8.07 kB
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Maya1 special-token ids -------------------------------------------------
# NOTE(review): SOH / EOH / SOA are read here as "start of human turn",
# "end of human turn" and "start of audio" purely from their names — confirm
# against the Maya1 tokenizer config.
BOS_ID = 128000
TEXT_EOT_ID = 128009
CODE_START_TOKEN_ID = 128257  # final token of the prompt built by build_prompt
CODE_END_TOKEN_ID = 128258    # generation stop token; ends the SNAC code stream
SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261

# --- SNAC audio-code tokens --------------------------------------------------
# Audio is emitted as frames of SNAC_TOKENS_PER_FRAME tokens. The valid code
# ids span exactly SNAC_TOKENS_PER_FRAME * 4096 consecutive ids starting at
# CODE_TOKEN_OFFSET (presumably one 4096-code band per frame slot, matching
# the `% 4096` in unpack_snac_from_7 — confirm).
SNAC_TOKENS_PER_FRAME = 7
CODE_TOKEN_OFFSET = 128266
SNAC_MIN_ID = CODE_TOKEN_OFFSET
SNAC_MAX_ID = CODE_TOKEN_OFFSET + SNAC_TOKENS_PER_FRAME * 4096 - 1  # == 156937
def build_prompt(tokenizer, description: str, text: str) -> str:
    """Build the formatted Maya1 prompt string.

    The prompt is the concatenation, in order, of: the decoded SOH token,
    the BOS token, ``<description="{description}"> {text}``, and the decoded
    TEXT_EOT, EOH, SOA and CODE_START tokens. The caller re-tokenizes this
    string and lets the model continue it with SNAC code tokens.

    Args:
        tokenizer: Tokenizer providing ``decode(list[int]) -> str`` and a
            ``bos_token`` attribute.
        description: Voice description placed inside the ``<description="...">``
            tag. It is inserted verbatim (double quotes are not escaped).
        text: The text to be spoken.

    Returns:
        The prompt string.
    """
    def special(token_id: int) -> str:
        # Render a single special-token id back to its text form.
        return tokenizer.decode([token_id])

    # Some tokenizers leave ``bos_token`` unset (None), which would make the
    # concatenation below raise TypeError; fall back to decoding BOS_ID.
    bos_token = tokenizer.bos_token
    if bos_token is None:
        bos_token = special(BOS_ID)

    formatted_text = f'<description="{description}"> {text}'
    return "".join((
        special(SOH_ID),
        bos_token,
        formatted_text,
        special(TEXT_EOT_ID),
        special(EOH_ID),
        special(SOA_ID),
        special(CODE_START_TOKEN_ID),
    ))
def extract_snac_codes(token_ids: list) -> list:
    """Return the SNAC code tokens that precede the first CODE_END token.

    Scanning stops at the first ``CODE_END_TOKEN_ID``; if that token is
    absent, the whole input is scanned. Of the scanned tokens, only those
    whose id lies in the inclusive range ``[SNAC_MIN_ID, SNAC_MAX_ID]`` are
    kept, in their original order.
    """
    codes = []
    for tok in token_ids:
        if tok == CODE_END_TOKEN_ID:
            break  # everything after the end-of-codes marker is ignored
        if SNAC_MIN_ID <= tok <= SNAC_MAX_ID:
            codes.append(tok)
    return codes
def unpack_snac_from_7(snac_tokens: list) -> list:
    """Unpack 7-token SNAC frames into the codec's 3 hierarchical code levels.

    A single trailing ``CODE_END_TOKEN_ID`` is ignored, and any incomplete
    final frame is dropped. Each token is turned back into a code by
    subtracting ``CODE_TOKEN_OFFSET`` and reducing modulo 4096. Within a
    frame, the seven slots are distributed as follows:

        level 1 <- slot 0
        level 2 <- slots 1, 4
        level 3 <- slots 2, 3, 5, 6

    Returns:
        ``[level1, level2, level3]``: integer lists of lengths F, 2F and 4F
        for F complete frames (three empty lists when F == 0).
    """
    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
        snac_tokens = snac_tokens[:-1]
    n_frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME

    level1, level2, level3 = [], [], []
    for start in range(0, n_frames * 7, 7):
        c0, c1, c2, c3, c4, c5, c6 = (
            (tok - CODE_TOKEN_OFFSET) % 4096
            for tok in snac_tokens[start:start + 7]
        )
        level1.append(c0)
        level2 += [c1, c4]
        level3 += [c2, c3, c5, c6]
    return [level1, level2, level3]
# Translation tables from the instruction vocabulary (matched
# case-insensitively) to the wording used in the Maya1 description.
_AGE_GROUP_PHRASES = {
    "senior": "in the 40s age",
    "adult": "in the 30s age",
    "young_adult": "in the 20s age",
}
_DEFAULT_AGE_PHRASE = "in the 20s age"  # missing or unrecognised age_group
_ACCENT_NAMES = {
    "us": "American",
    "uk": "British",
    "au": "Australian",
    "in": "Indian",
    "neutral": "Asian American",
    "other": "American",
}
# Emotion: neutral, energetic, excited, sad, sarcastic, dry.
# Emotions not listed here are passed through unchanged.
_EMOTION_ALIASES = {
    "happy": "energetic",
    "angry": "sarcastic",
    "calm": "neutral",
    "serious": "dry",
    "fearful": "sad",
}
_SPEED_ALIASES = {"normal": "conversational"}
# Timbre: deep, warm, gravelly, smooth, raspy, nasally, throaty, harsh.
_TONE_ALIASES = {
    "cold": "harsh",
    "friendly": "warm",
    "formal": "smooth",
    "casual": "gravelly",
    "authoritative": "throaty",
}


def format_description(description: str) -> str:
    """Convert a pipe-separated voice spec into a natural-language description.

    Input format:
        ``description`` looks like
        ``"| gender: male | age_group: adult | accent: us | pitch: mid"``,
        i.e. ``|``-separated ``key: value`` pairs.
        - Segments without ``:`` are ignored.
        - Later duplicate keys win.
        - Keys and mapped values are matched case-insensitively.

    Recognised keys:
        gender, age_group, accent, pitch, speed, emotion and tone. Each is
        translated through the tables above; unknown accent, emotion, speed
        and tone values pass through as given.

    Returns:
        Up to two sentences, e.g. ``"Realistic male voice in the 30s age
        with american accent. Mid pitch, neutral timbre, conversational
        pacing, smooth tone."``. The second sentence is omitted when none of
        pitch, emotion, speed or tone is present.
    """
    fields = {}
    for segment in description.strip().split("|"):
        key, sep, value = segment.partition(":")
        if sep:
            fields[key.strip().lower()] = value.strip()

    gender = fields.get("gender", "")
    age_group = fields.get("age_group", "")
    accent = fields.get("accent", "")
    pitch = fields.get("pitch", "")
    speed = fields.get("speed", "")
    emotion = fields.get("emotion", "")
    tone = fields.get("tone", "")

    # Drop the gender word entirely when absent (avoids "Realistic  voice").
    voice = f"{gender} voice" if gender else "voice"
    age_phrase = _AGE_GROUP_PHRASES.get(age_group.lower(), _DEFAULT_AGE_PHRASE)
    sentence1 = f"Realistic {voice} {age_phrase}"
    if accent:
        # Unknown accents pass through; the final accent word is lower-cased.
        accent_name = _ACCENT_NAMES.get(accent.lower(), accent)
        sentence1 += f" with {accent_name.lower()} accent"

    details = []
    if pitch:
        details.append(f"{pitch.capitalize()} pitch")
    if emotion:
        details.append(f"{_EMOTION_ALIASES.get(emotion.lower(), emotion)} timbre")
    if speed:
        details.append(f"{_SPEED_ALIASES.get(speed.lower(), speed)} pacing")
    if tone:
        details.append(f"{_TONE_ALIASES.get(tone.lower(), tone)} tone")

    if not details:
        # Avoid the dangling ". ." that an empty second sentence would produce.
        return sentence1 + "."
    return f"{sentence1}. {', '.join(details)}."
class Miner:
    """Vocence miner wrapper for Maya + SNAC inference.

    Loads a causal LM and tokenizer from a local Hugging Face repo directory,
    plus a SNAC audio codec. It turns (voice instruction, text) pairs into
    mono float32 waveforms.
    """

    # Sample rate reported by generate_wav. Assumes a 24 kHz SNAC
    # checkpoint — TODO confirm against the repo's snac_model config.
    SAMPLE_RATE = 24000
    # Public Hub checkpoint used when no local ``snac_model`` directory exists.
    _SNAC_HUB_ID = "hubertsiuzdak/snac_24khz"
    # Number of leading decoded samples discarded from every waveform.
    # Kept from the original implementation; presumably decoder warm-up
    # artefacts — confirm.
    _TRIM_SAMPLES = 2048

    def __init__(self, path_hf_repo: Path) -> None:
        """Load the LM, tokenizer and SNAC codec.

        Args:
            path_hf_repo: Directory holding the model and tokenizer files and,
                optionally, a ``snac_model`` subdirectory with the codec.
        """
        self._repo_path = Path(path_hf_repo).resolve()
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            str(self._repo_path),
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(self._repo_path),
            trust_remote_code=True,
        )
        self.snac_model = (
            SNAC.from_pretrained(self._snac_source()).eval().to(self._device)
        )

    def _snac_source(self) -> str:
        """Pick the SNAC checkpoint location.

        Lookup order:
            1. ``<repo>/snac_model``
            2. ``./snac_model``
            3. the public Hub checkpoint

        A bare ``"snac_model"`` id only works when that directory exists in
        the CWD; it is not a Hub repo, hence the explicit Hub fallback.
        """
        for candidate in (self._repo_path / "snac_model", Path("snac_model")):
            if candidate.is_dir():
                return str(candidate)
        return self._SNAC_HUB_ID

    def warmup(self) -> None:
        """Run one short synthesis so first-request latency is paid up front."""
        _ = self.generate_wav(
            instruction="| gender: male | pitch: mid | speed: normal | age_group: adult | emotion: calm | tone: formal | accent: us",
            text="This is a warmup utterance for the voice engine.",
        )

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize ``text`` in the voice described by ``instruction``.

        Args:
            instruction: Pipe-separated ``key: value`` voice attributes, as
                accepted by ``format_description``.
            text: Text to speak.

        Returns:
            ``(audio, sample_rate)``: a 1-D float32 waveform and
            ``SAMPLE_RATE``.

        Raises:
            RuntimeError: If the model produced fewer SNAC tokens than one
                full frame.
        """
        description = format_description(instruction)
        prompt = build_prompt(self.tokenizer, description, text)
        # The model was placed with device_map="auto", so send inputs to the
        # model's reported device rather than a hard-coded "cuda".
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                min_new_tokens=28,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Keep only the newly generated continuation, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = outputs[0, prompt_len:].tolist()
        snac_tokens = extract_snac_codes(generated_ids)
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            raise RuntimeError("Not enough SNAC tokens generated for decoding.")
        return self._decode_snac(snac_tokens), self.SAMPLE_RATE

    def _decode_snac(self, snac_tokens: list) -> np.ndarray:
        """Decode SNAC code tokens to a float32 waveform.

        The first ``_TRIM_SAMPLES`` samples are dropped when the waveform is
        longer than that.
        """
        levels = unpack_snac_from_7(snac_tokens)
        codes = [
            torch.tensor(level, dtype=torch.long, device=self._device).unsqueeze(0)
            for level in levels
        ]
        with torch.inference_mode():
            z_q = self.snac_model.quantizer.from_codes(codes)
            # .float() guards against a reduced-precision codec, since NumPy
            # cannot represent bfloat16.
            audio = self.snac_model.decoder(z_q)[0, 0].float().cpu().numpy()
        if len(audio) > self._TRIM_SAMPLES:
            audio = audio[self._TRIM_SAMPLES:]
        return audio.astype(np.float32)