# NOTE(review): removed "Spaces / Sleeping" page-scrape residue that preceded
# the module docstring — it is HF Space status text, not part of this script.
"""Extract Genshin Impact voice references and encode voice states for Pocket TTS.

Uses HuggingFace Datasets Server API to filter by speaker name directly.
One API call per character — no parquet scanning needed.
Resumable: skips characters that already have voice_state.safetensors.

Usage:
    python scripts/setup_genshin_voices.py
"""
import json
import io
import sys
import time
from pathlib import Path

import numpy as np
import requests
import soundfile as sf
import scipy.io.wavfile

# Make the project root importable so `scripts.*` resolves when this file is
# run directly (python scripts/setup_genshin_voices.py).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Character names may contain non-ASCII glyphs; force UTF-8 console output.
sys.stdout.reconfigure(encoding="utf-8")

from scripts.tag_genshin_voices import GENSHIN_CHARACTERS

# HuggingFace Datasets Server row-filter endpoint (SQL-like WHERE clauses).
HF_FILTER_URL = "https://datasets-server.huggingface.co/filter"
DATASET = "simon3000/genshin-voice"
MAX_RETRIES = 3  # retry budget for both the filter call and the audio download
| def _sanitize_name(name: str) -> str: | |
| return name.lower().replace(" ", "_").replace(":", "").replace("'", "") | |
def _encode_voice_state(model, ref_path, state_path):
    """Compute a voice-cloning state from a reference WAV and persist it.

    Returns the wall-clock seconds spent encoding.
    """
    # Imported lazily: pocket-tts may be absent when only extracting refs.
    from pocket_tts.models.tts_model import export_model_state

    started = time.time()
    voice_state = model.get_state_for_audio_prompt(str(ref_path))
    export_model_state(voice_state, str(state_path))
    return time.time() - started
def _fetch_best_clip(speaker: str) -> tuple:
    """Fetch the longest English clip for a speaker via HF Datasets API.
    Returns (audio_bytes, transcription, duration) or (None, None, 0) on failure.
    """
    # Escape single quotes in speaker name
    safe_speaker = speaker.replace("'", "\\'")
    query = {
        "dataset": DATASET,
        "config": "default",
        "split": "train",
        "where": f"speaker='{safe_speaker}' AND language='English(US)'",
        "length": 100,
    }
    attempt = 0
    while attempt < MAX_RETRIES:
        last_try = attempt == MAX_RETRIES - 1
        try:
            resp = requests.get(HF_FILTER_URL, params=query, timeout=120)
            if resp.status_code != 200:
                if last_try:
                    return None, None, 0
                time.sleep(2)
                attempt += 1
                continue
            rows = resp.json().get("rows", [])
            if not rows:
                return None, None, 0
            # The longest transcription marks the best reference clip.
            best_row = max(
                rows,
                key=lambda entry: len(entry["row"].get("transcription", "") or ""),
            )["row"]
            transcript = best_row.get("transcription", "")
            clips = best_row.get("audio", [])
            if not isinstance(clips, list) or not clips:
                return None, None, 0
            src_url = clips[0].get("src", "")
            if not src_url:
                return None, None, 0
            # Download audio
            audio_resp = requests.get(src_url, timeout=60)
            if audio_resp.status_code != 200:
                if last_try:
                    return None, None, 0
                time.sleep(2)
                attempt += 1
                continue
            samples, rate = sf.read(io.BytesIO(audio_resp.content))
            return audio_resp.content, transcript, len(samples) / rate
        except Exception as e:
            if last_try:
                print(f"ERROR: {e}")
                return None, None, 0
            time.sleep(2)
            attempt += 1
    return None, None, 0
def main():
    """Run the pipeline: load the TTS model, fetch/encode missing voices,
    then write the catalog JSON. Resumable — finished characters are skipped.
    """
    print("=" * 60)
    print("Genshin Voice Setup -- Extract & Encode Voice References")
    print("=" * 60)
    voices_dir = PROJECT_ROOT / "data" / "genshin_voices"
    catalog_path = PROJECT_ROOT / "data" / "genshin_voice_catalog.json"
    voices_dir.mkdir(parents=True, exist_ok=True)
    # Load Pocket TTS model upfront
    print("\n[1/3] Loading Pocket TTS model...")
    model = None
    try:
        from pocket_tts import TTSModel
        model = TTSModel.load_model()
        print(f" Model loaded (has_voice_cloning={model.has_voice_cloning})")
        if not model.has_voice_cloning:
            # Without cloning support we can still save reference WAVs;
            # drop the model so later stages skip encoding.
            print(" WARNING: Voice cloning not available, will extract refs only")
            model = None
    except ImportError:
        print(" WARNING: pocket-tts not installed, will extract refs only")
    # Check which characters are fully done
    already_done = set()  # voice_state.safetensors exists -> nothing to do
    has_ref_only = set()  # reference.wav exists but state missing -> encode only
    for char_name in GENSHIN_CHARACTERS:
        char_id = _sanitize_name(char_name)
        state_path = voices_dir / char_id / "voice_state.safetensors"
        ref_path = voices_dir / char_id / "reference.wav"
        if state_path.exists():
            already_done.add(char_name)
        elif ref_path.exists():
            has_ref_only.add(char_name)
    if already_done:
        print(f"\n {len(already_done)} characters fully done (skipping)")
    # Encode any characters that have reference.wav but no voice_state
    if has_ref_only and model:
        print(f"\n Encoding {len(has_ref_only)} existing references...")
        for char_name in sorted(has_ref_only):
            char_id = _sanitize_name(char_name)
            ref_path = voices_dir / char_id / "reference.wav"
            state_path = voices_dir / char_id / "voice_state.safetensors"
            try:
                elapsed = _encode_voice_state(model, ref_path, state_path)
                already_done.add(char_name)
                print(f" {char_name}: encoded in {elapsed:.1f}s")
            except Exception as e:
                # Best-effort: report and move on to the next character.
                print(f" {char_name}: FAILED - {e}")
    # Determine what still needs extraction
    targets = [name for name in GENSHIN_CHARACTERS if name not in already_done]
    if not targets:
        print("\n All characters done!")
        _save_catalog(catalog_path, {}, voices_dir)
        return
    print(f"\n {len(targets)} characters still need extraction")
    # Fetch and encode each character via HF API
    print("\n[2/3] Fetching clips via HuggingFace Datasets API...")
    extracted = {}  # char_name -> {path, duration, transcript} for the catalog
    total_encoded = len(already_done)
    failed = []
    for i, char_name in enumerate(targets):
        char_id = _sanitize_name(char_name)
        char_dir = voices_dir / char_id
        char_dir.mkdir(parents=True, exist_ok=True)
        ref_path = char_dir / "reference.wav"
        state_path = char_dir / "voice_state.safetensors"
        print(f" [{i+1}/{len(targets)}] {char_name}...", end=" ", flush=True)
        audio_bytes, trans, duration = _fetch_best_clip(char_name)
        if audio_bytes is None:
            print("NOT FOUND")
            failed.append(char_name)
            continue
        # Save reference WAV: decode whatever container HF served, then
        # re-write as 16-bit PCM (clipped to the int16 range).
        audio_data, sr = sf.read(io.BytesIO(audio_bytes))
        samples = audio_data.astype(np.float32)
        samples_int16 = np.clip(samples * 32767, -32768, 32767).astype(np.int16)
        scipy.io.wavfile.write(str(ref_path), sr, samples_int16)
        # Encode voice state immediately
        if model:
            try:
                elapsed = _encode_voice_state(model, ref_path, state_path)
                total_encoded += 1
                print(f"{duration:.1f}s ref -> encoded in {elapsed:.1f}s [{total_encoded}/{len(GENSHIN_CHARACTERS)}]")
            except Exception as e:
                # Keep the ref on disk; a later run will retry the encode.
                print(f"ref saved ({duration:.1f}s), encode FAILED - {e}")
        else:
            print(f"ref saved ({duration:.1f}s)")
        extracted[char_name] = {
            "path": str(ref_path),
            "duration": duration,
            "transcript": trans,
        }
        # Small delay to avoid rate limiting
        time.sleep(0.5)
    print(f"\n\n[3/3] Summary")
    print(f" Extracted: {len(extracted)} new clips")
    print(f" Total encoded: {total_encoded}/{len(GENSHIN_CHARACTERS)}")
    if failed:
        print(f" Failed ({len(failed)}): {', '.join(failed)}")
    _save_catalog(catalog_path, extracted, voices_dir)
    print(f"\nDone! {total_encoded} total voices ready")
def _save_catalog(catalog_path: Path, extracted: dict, voices_dir: Path):
    """Save voice catalog JSON."""
    entries = []
    for char_name, info in GENSHIN_CHARACTERS.items():
        char_id = _sanitize_name(char_name)
        char_dir = voices_dir / char_id
        ref_path = char_dir / "reference.wav"
        state_path = char_dir / "voice_state.safetensors"
        record = {
            "id": f"genshin_{char_id}",
            "display_name": char_name,
            "source_anime": "Genshin Impact",
            "gender": info["gender"],
            "age_category": info["age"],
            "vocal_traits": info["traits"],
            "pitch_range": "mid",
            "clip_count": 500,
            # Paths are recorded only when the file is actually on disk.
            "reference_path": str(ref_path) if ref_path.exists() else None,
            "voice_state_path": str(state_path) if state_path.exists() else None,
            "preset_voice": None,
            "archetype": info.get("archetype", ""),
        }
        if char_name in extracted:
            fetched = extracted[char_name]
            record["reference_transcript"] = fetched.get("transcript", "")
            record["reference_duration"] = fetched.get("duration", 0)
        entries.append(record)
    with open(catalog_path, "w", encoding="utf-8") as f:
        json.dump(entries, f, indent=2, ensure_ascii=False)
    # Tally: "ready" has an encoded state, "ref only" has just the WAV.
    ready = 0
    ref_only = 0
    for entry in entries:
        if entry["voice_state_path"]:
            ready += 1
        elif entry["reference_path"]:
            ref_only += 1
    print(f" Catalog: {ready} ready (cloned), {ref_only} ref only, {len(entries)} total")
# Script entry point: run the full extract-and-encode pipeline.
if __name__ == "__main__":
    main()