Swagcrew committed on
Commit
8c17c76
·
verified ·
1 Parent(s): 9f00efa

Upload gen_samples.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gen_samples.py +170 -0
gen_samples.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate voice clone samples from all quantized Fish Speech S2 Pro models."""
3
+ import os, sys, json, time, gc, traceback, subprocess
4
+ import torch
5
+
6
# Set before the fish_speech imports below.
# NOTE(review): presumably disables Hugging Face tokenizers' internal
# parallelism (and its fork warning) - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Device and dtype shared by encode_ref, generate_clone and main.
DEVICE = "cuda"
DTYPE = torch.bfloat16
# Text placed next to the reference audio codes in the user message
# (presumably the transcript of the reference clip - confirm).
REF_TEXT = "Let me get this straight. You think that your client, one of the wealthiest most powerful men in the world, is secretly a vigilante who spends his nights beating criminals to a pulp with his bare hands. And your plan is to blackmail this person."
# Text placed in the assistant message, i.e. what the model is asked to speak.
GEN_TEXT = "Every man's life ends the same way. It is only the details of how he lived that distinguish one man from another."
# Directory the generated .wav files are written to and later uploaded from.
OUT = "/tmp/samples"
# Parent directory for the per-model download folders.
MODEL_DIR = "/tmp/models"

print("=== Fish Speech Voice Clone Sample Generator ===")
# Queries CUDA device 0 at import time, so the script needs a visible GPU
# before any function is called.
print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

os.makedirs(OUT, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Setup fish-speech: put the checkout at /app/fish-speech first on the import
# path so the fish_speech imports that follow resolve to it.
sys.path.insert(0, "/app/fish-speech")
22
+
23
+ from fish_speech.models.text2semantic.inference import init_model, load_codec_model, generate, decode_one_token_ar
24
+ from fish_speech.conversation import Conversation, Message
25
+ from fish_speech.content_sequence import TextPart, VQPart
26
+ import torchaudio, soundfile as sf
27
+
28
def load_ref_audio(ref_path):
    """Load the reference clip as a single-channel waveform at 44.1 kHz.

    Args:
        ref_path: Path of the audio file, handed to ``torchaudio.load``.

    Returns:
        The waveform tensor, averaged over its first axis when that axis has
        more than one entry, and resampled to 44100 Hz when the file uses a
        different rate.
    """
    target_rate = 44100
    waveform, source_rate = torchaudio.load(ref_path)

    # More than one row on axis 0 means multi-channel audio: average to mono
    # while keeping the leading axis.
    is_multichannel = waveform.shape[0] > 1
    if is_multichannel:
        waveform = waveform.mean(dim=0, keepdim=True)

    if source_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, source_rate, target_rate
        )

    return waveform
33
+
34
def encode_ref(codec, wav):
    """Encode a reference waveform into codec tokens.

    Args:
        codec: Codec model exposing ``encode``; its result may be either the
            tokens themselves or a tuple whose first element is the tokens.
        wav: Waveform tensor as returned by ``load_ref_audio``. A leading axis
            is added before the call to ``codec.encode``.

    Returns:
        The tokens as a NumPy array on the CPU.
    """
    wav = wav.to(DEVICE)
    # Pure inference: disable autograd so no graph is recorded for the encoder
    # pass (saves GPU memory and keeps the output free of grad history).
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        enc = codec.encode(wav.unsqueeze(0))
    tokens = enc[0] if isinstance(enc, tuple) else enc
    # detach() is a no-op under no_grad but guards .numpy() should the codec
    # ever hand back a tensor that is attached to a graph.
    return tokens.detach().cpu().numpy()
40
+
41
def generate_clone(model, codec, ref_tokens, ref_text, gen_text, out_path):
    """Synthesize ``gen_text`` with the reference voice and save it as audio.

    Builds a two-message conversation (user: reference codes plus
    ``ref_text``; assistant: ``gen_text``), encodes it into a prompt, runs
    ``generate``, decodes the result with ``codec`` and writes the waveform to
    ``out_path`` with ``soundfile``.

    Args:
        model: Text-to-semantic model; must expose ``config.num_codebooks``,
            ``config.max_seq_len``, ``tokenizer`` and ``setup_caches``.
        codec: Codec model exposing ``decode`` and, optionally,
            ``sample_rate`` (44100 is assumed when the attribute is absent).
        ref_tokens: Reference codes as produced by ``encode_ref``.
        ref_text: Text paired with the reference codes in the user message.
        gen_text: Text to be spoken.
        out_path: Destination path of the audio file.

    Returns:
        ``True`` once the file has been written; failures propagate as
        exceptions.
    """
    conv = Conversation()
    conv.append(Message(role="user", parts=[
        VQPart(codes=ref_tokens),
        TextPart(text=ref_text),
    ]))
    conv.append(Message(role="assistant", parts=[TextPart(text=gen_text)]))

    nc = model.config.num_codebooks
    tokenizer = model.tokenizer
    result = conv.encode_for_inference(tokenizer, nc)
    # encode_for_inference returns (prompt_tensor,) or prompt_tensor
    prompt = result[0] if isinstance(result, tuple) else result

    # One row per codebook plus one extra row; both tensors are all-zero
    # placeholders sized to the prompt length.
    cd = 1 + nc
    am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE)
    ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE)

    # Allocate the caches only once per model instance; the private ``_cd``
    # flag on the model records that this has already happened.
    if not getattr(model, '_cd', False):
        model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE)
        model._cd = True

    # Inference only: run without autograd so no graph is kept alive.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        r = generate(model=model, prompt=prompt, max_new_tokens=1024,
                     audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30,
                     decode_one_token=decode_one_token_ar)

    # NOTE(review): this indexing assumes ``generate`` returns a 3-D tensor -
    # confirm against the installed fish-speech version.
    codes = r[0:1, :, :].unsqueeze(0)
    # Decoding under no_grad (plus the detach below) matters: calling
    # .numpy() on a tensor that requires grad raises a RuntimeError, which
    # would happen if the codec's parameters have requires_grad=True.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=DTYPE):
        audio = codec.decode(codes.to(DEVICE))

    np_audio = audio.detach().squeeze().cpu().float().numpy()
    sr = getattr(codec, 'sample_rate', 44100)
    sf.write(out_path, np_audio, sr)
    dur = len(np_audio) / sr
    print(f" Saved {out_path} ({dur:.1f}s)")
    return True
82
+
83
+ # --- Models to test ---
84
# Checkpoints to sample from, keyed by a short name. main() uses the key for
# the local download folder (MODEL_DIR/<key>) and for the output file name,
# and passes "source" to snapshot_download as the Hub repo id.
# "quant" is not read by any code visible in this file.
MODELS = {
    "baseline_bf16": {"source": "fishaudio/s2-pro", "quant": None},
    "fp8": {"source": "drbaph/s2-pro-fp8", "quant": None},
}
88
+
89
def main():
    """Generate one voice-clone sample per entry in ``MODELS`` and upload them.

    Steps:
      1. Load the reference clip (``/app/reference`` first, ``/tmp/reference``
         as fallback).
      2. For every model: download it if ``config.json`` is missing locally,
         load model and codec, encode the reference, generate the sample.
         Any failure is recorded for that model and the loop continues.
      3. Print a per-model summary.
      4. Upload every ``.wav`` found in ``OUT`` to the Hub repo; upload errors
         are printed, not raised.

    Raises:
        FileNotFoundError: If the reference clip exists at neither location.
    """
    # Load reference audio
    ref_path = "/app/reference/morgan_ref.wav"
    if not os.path.exists(ref_path):
        fallback_path = "/tmp/reference/morgan_ref.wav"
        if not os.path.exists(fallback_path):
            # Fail early with both candidate paths instead of an opaque
            # decoder error from the audio loader.
            raise FileNotFoundError(
                f"Reference audio not found at {ref_path} or {fallback_path}"
            )
        ref_path = fallback_path
    print(f"\n[1] Loading reference audio: {ref_path}")
    ref_wav = load_ref_audio(ref_path)

    results = {}

    for name, cfg in MODELS.items():
        print(f"\n{'='*60}")
        print(f" {name.upper()}")
        print(f"{'='*60}")

        model_id = cfg["source"]
        local_dir = f"{MODEL_DIR}/{name}"
        out_path = f"{OUT}/fish_{name}_morgan_clone.wav"

        # Bound up front so the cleanup in ``finally`` is valid even when
        # download or loading fails before they are assigned.
        model = codec = None
        try:
            # Download if needed
            if not os.path.exists(f"{local_dir}/config.json"):
                print(f" Downloading {model_id}...")
                from huggingface_hub import snapshot_download
                snapshot_download(model_id, local_dir=local_dir, token=os.environ.get("HF_TOKEN"))

            # Load model
            print(" Loading model...")
            model, _ = init_model(local_dir, DEVICE, DTYPE, compile=False)
            codec = load_codec_model(f"{local_dir}/codec.pth", DEVICE, DTYPE)

            # Encode reference (done per model because each one ships its
            # own codec checkpoint)
            ref_tokens = encode_ref(codec, ref_wav)

            # Generate
            ok = generate_clone(model, codec, ref_tokens, REF_TEXT, GEN_TEXT, out_path)
            results[name] = {"ok": ok, "file": out_path}
        except Exception as e:
            # Per-model boundary: record the failure and move on so one bad
            # checkpoint does not abort the remaining models.
            print(f" FAILED: {e}")
            traceback.print_exc()
            results[name] = {"ok": False, "error": str(e)}
        finally:
            # Drop the references and release GPU memory before the next
            # model is loaded, on success and on failure alike.
            model = codec = None
            gc.collect()
            torch.cuda.empty_cache()

    # Also generate from GGUF models using s2.cpp if available
    # For now, just upload what we have

    # Summary
    print(f"\n{'='*60}")
    print(" RESULTS")
    print(f"{'='*60}")
    for name, r in results.items():
        status = "✅" if r["ok"] else "❌"
        print(f" {status} {name}: {r.get('file', r.get('error',''))}")

    # Upload to Hub
    print("\n[Final] Uploading samples to Hub...")
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        repo = "Swagcrew/fish-speech-s2-quantized"
        # Uploads every .wav present in OUT, not only the files written by
        # this run.
        for fn in os.listdir(OUT):
            if fn.endswith(".wav"):
                fpath = os.path.join(OUT, fn)
                api.upload_file(
                    path_or_fileobj=fpath,
                    path_in_repo=f"samples/{fn}",
                    repo_id=repo,
                    repo_type="model"
                )
                print(f" Uploaded samples/{fn}")
        print(f"\n All at https://huggingface.co/{repo}/tree/main/samples")
    except Exception as e:
        # Best-effort: the samples stay on disk even if the upload fails.
        print(f" Upload error: {e}")
        traceback.print_exc()

    print("\nDONE!")
168
+
169
# Run the sample generation only when executed as a script.
if __name__ == "__main__":
    main()