Swagcrew commited on
Commit
3a38d65
·
verified ·
1 Parent(s): ecf804f

Upload gen_samples.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gen_samples.py +63 -59
gen_samples.py CHANGED
@@ -1,26 +1,17 @@
1
  #!/usr/bin/env python3
2
- """Generate voice clone samples using fish-speech's generate_long API."""
3
- import os, sys, json, time, gc, traceback
4
  import torch
5
- import torchaudio
6
- import soundfile as sf
7
 
8
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
  os.environ["HF_HOME"] = "/tmp/hf_cache"
10
  sys.path.insert(0, "/app/fish-speech")
11
 
12
- DEVICE = "cuda"
13
- DTYPE = torch.bfloat16
14
- REF_TEXT = "Let me get this straight. You think that your client, one of the wealthiest most powerful men in the world, is secretly a vigilante who spends his nights beating criminals to a pulp with his bare hands. And your plan is to blackmail this person."
15
  GEN_TEXT = "Every man's life ends the same way. It is only the details of how he lived that distinguish one man from another."
 
16
  OUT = "/tmp/samples"
17
-
18
  os.makedirs(OUT, exist_ok=True)
19
 
20
- from fish_speech.models.text2semantic.inference import (
21
- init_model, load_codec_model, encode_audio, generate_long
22
- )
23
-
24
  MODELS = [
25
  ("baseline_bf16", "fishaudio/s2-pro"),
26
  ("fp8", "drbaph/s2-pro-fp8"),
@@ -30,8 +21,6 @@ def main():
30
  print(f"=== Fish Speech Voice Clone Sample Generator ===")
31
  print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
32
 
33
- ref_path = "/app/reference/morgan_ref.wav"
34
-
35
  for name, model_id in MODELS:
36
  print(f"\n{'='*60}")
37
  print(f" {name.upper()} ({model_id})")
@@ -43,52 +32,67 @@ def main():
43
  from huggingface_hub import snapshot_download
44
  snapshot_download(model_id, local_dir=local_dir, token=os.environ.get("HF_TOKEN"))
45
 
46
- print(f" Loading model...")
47
- model, decode_fn = init_model(local_dir, DEVICE, DTYPE, compile=False)
48
- codec = load_codec_model(f"{local_dir}/codec.pth", DEVICE, DTYPE)
49
-
50
- with torch.device(DEVICE):
51
- model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
52
-
53
- print(f" Encoding reference audio...")
54
- prompt_tokens = encode_audio(ref_path, codec, DEVICE).cpu()
55
-
56
- print(f" Generating voice clone...")
57
  out_path = f"{OUT}/fish_{name}_morgan_clone.wav"
58
 
59
- try:
60
- for response in generate_long(
61
- model=model,
62
- device=DEVICE,
63
- decode_one_token=decode_fn,
64
- text=GEN_TEXT,
65
- max_new_tokens=1024,
66
- top_p=0.7,
67
- top_k=30,
68
- temperature=0.7,
69
- repetition_penalty=1.1,
70
- compile=False,
71
- iterative_prompt=False,
72
- chunk_length=0,
73
- prompt_text=[REF_TEXT],
74
- prompt_tokens=[prompt_tokens],
75
- ):
76
- if response.action == "sample":
77
- codes = response.codes
78
- with torch.no_grad():
79
- with torch.amp.autocast(device_type="cuda", dtype=DTYPE):
80
- audio = codec.decode(codes.unsqueeze(0).to(DEVICE))
81
- np_audio = audio.squeeze().cpu().float().numpy()
82
- sr = getattr(codec, 'sample_rate', 44100)
83
- sf.write(out_path, np_audio, sr)
84
- dur = len(np_audio) / sr
85
- print(f" ✅ Saved {out_path} ({dur:.1f}s)")
86
-
87
- except Exception as e:
88
- print(f" ❌ FAILED: {e}")
89
- traceback.print_exc()
90
-
91
- del model, codec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  gc.collect()
93
  torch.cuda.empty_cache()
94
 
@@ -100,7 +104,7 @@ def main():
100
  from huggingface_hub import HfApi
101
  api = HfApi()
102
  repo = "Swagcrew/fish-speech-s2-quantized"
103
- for fn in os.listdir(OUT):
104
  if fn.endswith(".wav"):
105
  api.upload_file(
106
  path_or_fileobj=os.path.join(OUT, fn),
 
1
  #!/usr/bin/env python3
2
+ """Generate voice clone samples using fish-speech CLI."""
3
+ import os, sys, json, time, gc, traceback, subprocess
4
  import torch
 
 
5
 
6
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
  os.environ["HF_HOME"] = "/tmp/hf_cache"
8
  sys.path.insert(0, "/app/fish-speech")
9
 
 
 
 
10
  GEN_TEXT = "Every man's life ends the same way. It is only the details of how he lived that distinguish one man from another."
11
+ REF_TEXT = "Let me get this straight. You think that your client, one of the wealthiest most powerful men in the world, is secretly a vigilante who spends his nights beating criminals to a pulp with his bare hands. And your plan is to blackmail this person."
12
  OUT = "/tmp/samples"
 
13
  os.makedirs(OUT, exist_ok=True)
14
 
 
 
 
 
15
  MODELS = [
16
  ("baseline_bf16", "fishaudio/s2-pro"),
17
  ("fp8", "drbaph/s2-pro-fp8"),
 
21
  print(f"=== Fish Speech Voice Clone Sample Generator ===")
22
  print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
23
 
 
 
24
  for name, model_id in MODELS:
25
  print(f"\n{'='*60}")
26
  print(f" {name.upper()} ({model_id})")
 
32
  from huggingface_hub import snapshot_download
33
  snapshot_download(model_id, local_dir=local_dir, token=os.environ.get("HF_TOKEN"))
34
 
 
 
 
 
 
 
 
 
 
 
 
35
  out_path = f"{OUT}/fish_{name}_morgan_clone.wav"
36
 
37
+ # Step 1: Generate semantic tokens using the CLI
38
+ semantic_dir = f"{OUT}/{name}_semantic"
39
+ os.makedirs(semantic_dir, exist_ok=True)
40
+
41
+ cmd = [
42
+ sys.executable, "-m", "fish_speech.models.text2semantic.inference",
43
+ "--text", f"<|speaker:0|>{GEN_TEXT}",
44
+ "--prompt-audio", "/app/reference/morgan_ref.wav",
45
+ "--prompt-text", REF_TEXT,
46
+ "--checkpoint-path", local_dir,
47
+ "--output-dir", semantic_dir,
48
+ "--num-samples", "1",
49
+ "--max-new-tokens", "1024",
50
+ "--top-p", "0.7",
51
+ "--top-k", "30",
52
+ "--temperature", "0.7",
53
+ "--no-iterative-prompt",
54
+ "--chunk-length", "0",
55
+ "--device", "cuda",
56
+ ]
57
+
58
+ print(f" Generating semantic tokens...")
59
+ env = {**os.environ, "PYTHONPATH": "/app/fish-speech"}
60
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
61
+
62
+ print(f" CLI stdout (last 500): {result.stdout[-500:]}")
63
+ if result.stderr:
64
+ print(f" CLI stderr (last 500): {result.stderr[-500:]}")
65
+
66
+ # Find generated .pt files
67
+ pt_files = [f for f in os.listdir(semantic_dir) if f.endswith('.pt')]
68
+ if not pt_files:
69
+ print(f" ❌ No .pt files generated")
70
+ continue
71
+
72
+ print(f" Generated {len(pt_files)} semantic files")
73
+
74
+ # Step 2: Decode semantic tokens to audio using codec
75
+ import torchaudio
76
+ import soundfile as sf
77
+ from fish_speech.models.text2semantic.inference import load_codec_model
78
+
79
+ codec = load_codec_model(f"{local_dir}/codec.pth", "cuda", torch.bfloat16)
80
+
81
+ for pt_file in pt_files:
82
+ codes = torch.load(os.path.join(semantic_dir, pt_file), map_location="cuda", weights_only=True)
83
+ print(f" Decoding {pt_file}, codes shape: {codes.shape}")
84
+
85
+ with torch.no_grad():
86
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
87
+ audio = codec.decode(codes.unsqueeze(0))
88
+
89
+ np_audio = audio.squeeze().cpu().float().numpy()
90
+ sr = getattr(codec, 'sample_rate', 44100)
91
+ sf.write(out_path, np_audio, sr)
92
+ dur = len(np_audio) / sr
93
+ print(f" ✅ Saved {out_path} ({dur:.1f}s)")
94
+
95
+ del codec
96
  gc.collect()
97
  torch.cuda.empty_cache()
98
 
 
104
  from huggingface_hub import HfApi
105
  api = HfApi()
106
  repo = "Swagcrew/fish-speech-s2-quantized"
107
+ for fn in sorted(os.listdir(OUT)):
108
  if fn.endswith(".wav"):
109
  api.upload_file(
110
  path_or_fileobj=os.path.join(OUT, fn),