from __future__ import annotations from pathlib import Path import numpy as np import torch from snac import SNAC from transformers import AutoModelForCausalLM, AutoTokenizer CODE_START_TOKEN_ID = 128257 CODE_END_TOKEN_ID = 128258 CODE_TOKEN_OFFSET = 128266 SNAC_MIN_ID = 128266 SNAC_MAX_ID = 156937 SNAC_TOKENS_PER_FRAME = 7 SOH_ID = 128259 EOH_ID = 128260 SOA_ID = 128261 BOS_ID = 128000 TEXT_EOT_ID = 128009 def build_prompt(tokenizer, description: str, text: str) -> str: """Build formatted prompt for Maya1.""" soh_token = tokenizer.decode([SOH_ID]) eoh_token = tokenizer.decode([EOH_ID]) soa_token = tokenizer.decode([SOA_ID]) sos_token = tokenizer.decode([CODE_START_TOKEN_ID]) eot_token = tokenizer.decode([TEXT_EOT_ID]) bos_token = tokenizer.bos_token formatted_text = f' {text}' prompt = ( soh_token + bos_token + formatted_text + eot_token + eoh_token + soa_token + sos_token ) return prompt def extract_snac_codes(token_ids: list) -> list: """Extract SNAC codes from generated tokens.""" try: eos_idx = token_ids.index(CODE_END_TOKEN_ID) except ValueError: eos_idx = len(token_ids) snac_codes = [ token_id for token_id in token_ids[:eos_idx] if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID ] return snac_codes def unpack_snac_from_7(snac_tokens: list) -> list: """Unpack 7-token SNAC frames to 3 hierarchical levels.""" if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID: snac_tokens = snac_tokens[:-1] frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME] if frames == 0: return [[], [], []] l1, l2, l3 = [], [], [] for i in range(frames): slots = snac_tokens[i*7:(i+1)*7] l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) l2.extend([ (slots[1] - CODE_TOKEN_OFFSET) % 4096, (slots[4] - CODE_TOKEN_OFFSET) % 4096, ]) l3.extend([ (slots[2] - CODE_TOKEN_OFFSET) % 4096, (slots[3] - CODE_TOKEN_OFFSET) % 4096, (slots[5] - CODE_TOKEN_OFFSET) % 4096, (slots[6] - CODE_TOKEN_OFFSET) % 4096, ]) return [l1, l2, l3] def format_description(description: str) -> str: parts = description.strip().split("|") data = {} # Parse into dict for part in parts: if ":" in part: key, value = part.split(":", 1) data[key.strip()] = value.strip() # Build components gender = data.get("gender", "") age_group = data.get("age_group", "") accent = data.get("accent", "") pitch = data.get("pitch", "") speed = data.get("speed", "") emotion = data.get("emotion", "") tone = data.get("tone", "") # Convert to natural language sentence1 = f"Realistic {gender} voice" if age_group == "senior": sentence1 += " in the 40s age" elif age_group == "adult": sentence1 += " in the 30s age" elif age_group == "young_adult": sentence1 += " in the 20s age" else: sentence1 += " in the 20s age" if accent: if accent.lower() == "us": accent = "American" elif accent.lower() == "uk": accent = "British" elif accent.lower() == "au": accent = "Australian" elif accent.lower() == "in": accent = "Indian" elif accent.lower() == "neutral": accent = "Asian American" elif accent.lower() == "other": accent = "American" sentence1 += f" with {accent.lower()} accent" sentence2_parts = [] if pitch: sentence2_parts.append(f"{pitch.capitalize()} pitch") if emotion: # Emotion: neutral, energetic, excited, sad, sarcastic, dry if emotion.lower() == "happy": emotion = "energetic" elif emotion.lower() == "angry": emotion = "sarcastic" elif emotion.lower() == "calm": emotion = "neutral" elif emotion.lower() == "serious": emotion = "dry" elif emotion.lower() == "fearful": emotion = "sad" sentence2_parts.append(f"{emotion} timbre") if speed: if speed.lower() == "normal": speed = "conversational" sentence2_parts.append(f"{speed} pacing") if tone: # Timbre: `deep`, `warm`, `gravelly`, `smooth`, `raspy`, `nasally`, `throaty`, `harsh` if tone.lower() == "cold": tone = "harsh" elif tone.lower() == "friendly": tone = "warm" elif tone.lower() == "formal": tone = "smooth" elif tone.lower() == "casual": tone = "gravelly" elif tone.lower() == "authoritative": tone = "throaty" sentence2_parts.append(f"{tone} tone") sentence2 = ", ".join(sentence2_parts) return sentence1 + ". " + sentence2 + "." class Miner: """Vocence miner wrapper for Maya + SNAC inference.""" def __init__(self, path_hf_repo: Path) -> None: self._repo_path = Path(path_hf_repo).resolve() self._device = "cuda" if torch.cuda.is_available() else "cpu" self.model = AutoModelForCausalLM.from_pretrained( str(self._repo_path), torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) self.tokenizer = AutoTokenizer.from_pretrained( str(self._repo_path), trust_remote_code=True, ) snac_path = self._repo_path / "snac_model" if snac_path.exists(): self.snac_model = SNAC.from_pretrained(str(snac_path)).eval() else: self.snac_model = SNAC.from_pretrained("snac_model").eval() if torch.cuda.is_available(): self.snac_model = self.snac_model.to("cuda") def warmup(self) -> None: _ = self.generate_wav( instruction="| gender: male | pitch: mid | speed: normal | age_group: adult | emotion: calm | tone: formal | accent: us", text="This is a warmup utterance for the voice engine.", ) def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]: description = format_description(instruction) prompt = build_prompt(self.tokenizer, description, text) inputs = self.tokenizer(prompt, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} with torch.inference_mode(): outputs = self.model.generate( **inputs, max_new_tokens=2048, min_new_tokens=28, temperature=0.4, top_p=0.9, repetition_penalty=1.1, do_sample=True, eos_token_id=CODE_END_TOKEN_ID, pad_token_id=self.tokenizer.pad_token_id, ) generated_ids = outputs[0, inputs["input_ids"].shape[1] :].tolist() snac_tokens = extract_snac_codes(generated_ids) if len(snac_tokens) < SNAC_TOKENS_PER_FRAME: raise RuntimeError("Not enough SNAC tokens generated for decoding.") levels = unpack_snac_from_7(snac_tokens) codes_tensor = [ torch.tensor(level, dtype=torch.long, device=self._device).unsqueeze(0) for level in levels ] with torch.inference_mode(): z_q = self.snac_model.quantizer.from_codes(codes_tensor) audio = self.snac_model.decoder(z_q)[0, 0].cpu().numpy() if len(audio) > 2048: audio = audio[2048:] return audio.astype(np.float32), 24000