# anime-gen-api / scripts/setup_genshin_voices.py
# Uploaded via huggingface_hub by AswinMathew (commit 7190fd0, verified).
"""Extract Genshin Impact voice references and encode voice states for Pocket TTS.
Uses HuggingFace Datasets Server API to filter by speaker name directly.
One API call per character — no parquet scanning needed.
Resumable: skips characters that already have voice_state.safetensors.
Usage:
python scripts/setup_genshin_voices.py
"""
import json
import io
import sys
import time
from pathlib import Path
import numpy as np
import requests
import soundfile as sf
import scipy.io.wavfile
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.stdout.reconfigure(encoding="utf-8")
from scripts.tag_genshin_voices import GENSHIN_CHARACTERS
HF_FILTER_URL = "https://datasets-server.huggingface.co/filter"
DATASET = "simon3000/genshin-voice"
MAX_RETRIES = 3
def _sanitize_name(name: str) -> str:
return name.lower().replace(" ", "_").replace(":", "").replace("'", "")
def _encode_voice_state(model, ref_path, state_path):
    """Encode the model's voice-cloning state from a reference WAV and save it.

    Derives the state via ``model.get_state_for_audio_prompt`` and exports it
    to *state_path*. Returns the elapsed wall-clock time in seconds.
    """
    # Imported lazily so the module loads even without pocket-tts installed.
    from pocket_tts.models.tts_model import export_model_state

    started = time.time()
    export_model_state(model.get_state_for_audio_prompt(str(ref_path)), str(state_path))
    return time.time() - started
def _fetch_best_clip(speaker: str) -> tuple:
"""Fetch the longest English clip for a speaker via HF Datasets API.
Returns (audio_bytes, transcription, duration) or (None, None, 0) on failure.
"""
# Escape single quotes in speaker name
safe_speaker = speaker.replace("'", "\\'")
params = {
"dataset": DATASET,
"config": "default",
"split": "train",
"where": f"speaker='{safe_speaker}' AND language='English(US)'",
"length": 100,
}
for attempt in range(MAX_RETRIES):
try:
r = requests.get(HF_FILTER_URL, params=params, timeout=120)
if r.status_code != 200:
if attempt < MAX_RETRIES - 1:
time.sleep(2)
continue
return None, None, 0
data = r.json()
rows = data.get("rows", [])
if not rows:
return None, None, 0
# Find longest transcription (best reference clip)
best = max(rows, key=lambda row: len(row["row"].get("transcription", "") or ""))
trans = best["row"].get("transcription", "")
audio_list = best["row"].get("audio", [])
if not audio_list or not isinstance(audio_list, list):
return None, None, 0
audio_src = audio_list[0].get("src", "")
if not audio_src:
return None, None, 0
# Download audio
r2 = requests.get(audio_src, timeout=60)
if r2.status_code != 200:
if attempt < MAX_RETRIES - 1:
time.sleep(2)
continue
return None, None, 0
audio_data, sr = sf.read(io.BytesIO(r2.content))
duration = len(audio_data) / sr
return r2.content, trans, duration
except Exception as e:
if attempt < MAX_RETRIES - 1:
time.sleep(2)
continue
print(f"ERROR: {e}")
return None, None, 0
return None, None, 0
def main() -> None:
    """Orchestrate the pipeline: load the Pocket TTS model, resume from any
    existing outputs, fetch one reference clip per character via the HF
    Datasets API, encode voice states, and write the catalog JSON."""
    print("=" * 60)
    print("Genshin Voice Setup -- Extract & Encode Voice References")
    print("=" * 60)
    voices_dir = PROJECT_ROOT / "data" / "genshin_voices"
    catalog_path = PROJECT_ROOT / "data" / "genshin_voice_catalog.json"
    voices_dir.mkdir(parents=True, exist_ok=True)
    # Load Pocket TTS model upfront
    print("\n[1/3] Loading Pocket TTS model...")
    # model stays None when cloning is unavailable; later steps then only
    # extract reference WAVs without encoding states.
    model = None
    try:
        from pocket_tts import TTSModel
        model = TTSModel.load_model()
        print(f" Model loaded (has_voice_cloning={model.has_voice_cloning})")
        if not model.has_voice_cloning:
            print(" WARNING: Voice cloning not available, will extract refs only")
            model = None
    except ImportError:
        print(" WARNING: pocket-tts not installed, will extract refs only")
    # Check which characters are fully done
    # "fully done" = voice_state.safetensors exists; a reference.wav alone
    # means extraction succeeded but encoding is still pending.
    already_done = set()
    has_ref_only = set()
    for char_name in GENSHIN_CHARACTERS:
        char_id = _sanitize_name(char_name)
        state_path = voices_dir / char_id / "voice_state.safetensors"
        ref_path = voices_dir / char_id / "reference.wav"
        if state_path.exists():
            already_done.add(char_name)
        elif ref_path.exists():
            has_ref_only.add(char_name)
    if already_done:
        print(f"\n {len(already_done)} characters fully done (skipping)")
    # Encode any characters that have reference.wav but no voice_state
    if has_ref_only and model:
        print(f"\n Encoding {len(has_ref_only)} existing references...")
        for char_name in sorted(has_ref_only):
            char_id = _sanitize_name(char_name)
            ref_path = voices_dir / char_id / "reference.wav"
            state_path = voices_dir / char_id / "voice_state.safetensors"
            try:
                elapsed = _encode_voice_state(model, ref_path, state_path)
                already_done.add(char_name)
                print(f" {char_name}: encoded in {elapsed:.1f}s")
            except Exception as e:
                # Best-effort: log and move on so one bad WAV doesn't stop the run.
                print(f" {char_name}: FAILED - {e}")
    # Determine what still needs extraction
    targets = [name for name in GENSHIN_CHARACTERS if name not in already_done]
    if not targets:
        print("\n All characters done!")
        _save_catalog(catalog_path, {}, voices_dir)
        return
    print(f"\n {len(targets)} characters still need extraction")
    # Fetch and encode each character via HF API
    print("\n[2/3] Fetching clips via HuggingFace Datasets API...")
    extracted = {}
    total_encoded = len(already_done)
    failed = []
    for i, char_name in enumerate(targets):
        char_id = _sanitize_name(char_name)
        char_dir = voices_dir / char_id
        char_dir.mkdir(parents=True, exist_ok=True)
        ref_path = char_dir / "reference.wav"
        state_path = char_dir / "voice_state.safetensors"
        print(f" [{i+1}/{len(targets)}] {char_name}...", end=" ", flush=True)
        audio_bytes, trans, duration = _fetch_best_clip(char_name)
        if audio_bytes is None:
            print("NOT FOUND")
            failed.append(char_name)
            continue
        # Save reference WAV
        # Re-decode the downloaded bytes and write as 16-bit PCM; the source
        # clip may be float — assumes samples are normalized to [-1, 1]
        # before the 32767 scaling (TODO confirm against the dataset format).
        audio_data, sr = sf.read(io.BytesIO(audio_bytes))
        samples = audio_data.astype(np.float32)
        samples_int16 = np.clip(samples * 32767, -32768, 32767).astype(np.int16)
        scipy.io.wavfile.write(str(ref_path), sr, samples_int16)
        # Encode voice state immediately
        # (keeps the run resumable — a crash later loses at most one character)
        if model:
            try:
                elapsed = _encode_voice_state(model, ref_path, state_path)
                total_encoded += 1
                print(f"{duration:.1f}s ref -> encoded in {elapsed:.1f}s [{total_encoded}/{len(GENSHIN_CHARACTERS)}]")
            except Exception as e:
                print(f"ref saved ({duration:.1f}s), encode FAILED - {e}")
        else:
            print(f"ref saved ({duration:.1f}s)")
        extracted[char_name] = {
            "path": str(ref_path),
            "duration": duration,
            "transcript": trans,
        }
        # Small delay to avoid rate limiting
        time.sleep(0.5)
    print(f"\n\n[3/3] Summary")
    print(f" Extracted: {len(extracted)} new clips")
    print(f" Total encoded: {total_encoded}/{len(GENSHIN_CHARACTERS)}")
    if failed:
        print(f" Failed ({len(failed)}): {', '.join(failed)}")
    _save_catalog(catalog_path, extracted, voices_dir)
    print(f"\nDone! {total_encoded} total voices ready")
def _save_catalog(catalog_path: Path, extracted: dict, voices_dir: Path):
    """Write the voice catalog JSON covering every known character.

    Each entry combines static metadata from GENSHIN_CHARACTERS with the
    on-disk state under *voices_dir* (paths are None when the corresponding
    file is missing). Characters present in *extracted* additionally get
    transcript and duration fields from this run.
    """
    catalog = []
    for display_name, info in GENSHIN_CHARACTERS.items():
        slug = _sanitize_name(display_name)
        wav_file = voices_dir / slug / "reference.wav"
        state_file = voices_dir / slug / "voice_state.safetensors"
        entry = {
            "id": f"genshin_{slug}",
            "display_name": display_name,
            "source_anime": "Genshin Impact",
            "gender": info["gender"],
            "age_category": info["age"],
            "vocal_traits": info["traits"],
            "pitch_range": "mid",
            "clip_count": 500,
            "reference_path": str(wav_file) if wav_file.exists() else None,
            "voice_state_path": str(state_file) if state_file.exists() else None,
            "preset_voice": None,
            "archetype": info.get("archetype", ""),
        }
        if display_name in extracted:
            details = extracted[display_name]
            entry["reference_transcript"] = details.get("transcript", "")
            entry["reference_duration"] = details.get("duration", 0)
        catalog.append(entry)
    catalog_path.write_text(
        json.dumps(catalog, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    ready = sum(1 for e in catalog if e["voice_state_path"])
    ref_only = sum(1 for e in catalog if e["reference_path"] and not e["voice_state_path"])
    print(f" Catalog: {ready} ready (cloned), {ref_only} ref only, {len(catalog)} total")
if __name__ == "__main__":
main()