Swagcrew committed on
Commit
e2f85c6
·
verified ·
1 Parent(s): 14700d3

Upload gen_samples.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gen_samples.py +60 -119
gen_samples.py CHANGED
@@ -1,176 +1,117 @@
1
  #!/usr/bin/env python3
2
- """Generate voice clone samples from all quantized Fish Speech S2 Pro models."""
3
- import os, sys, json, time, gc, traceback, subprocess
4
  import torch
 
 
5
 
6
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 
7
  DEVICE = "cuda"
8
  DTYPE = torch.bfloat16
9
  REF_TEXT = "Let me get this straight. You think that your client, one of the wealthiest most powerful men in the world, is secretly a vigilante who spends his nights beating criminals to a pulp with his bare hands. And your plan is to blackmail this person."
10
  GEN_TEXT = "Every man's life ends the same way. It is only the details of how he lived that distinguish one man from another."
11
  OUT = "/tmp/samples"
12
- MODEL_DIR = "/tmp/models"
13
-
14
- print("=== Fish Speech Voice Clone Sample Generator ===")
15
- print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
16
 
17
  os.makedirs(OUT, exist_ok=True)
18
- os.makedirs(MODEL_DIR, exist_ok=True)
19
 
20
- # Setup fish-speech
21
- sys.path.insert(0, "/app/fish-speech")
 
22
 
23
- from fish_speech.models.text2semantic.inference import init_model, load_codec_model, generate, decode_one_token_ar
24
- from fish_speech.conversation import Conversation, Message
25
- from fish_speech.content_sequence import TextPart, VQPart
26
- import torchaudio, soundfile as sf
27
-
28
- def load_ref_audio(ref_path):
29
- wav, sr = torchaudio.load(ref_path)
30
- if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True)
31
- if sr != 44100: wav = torchaudio.functional.resample(wav, sr, 44100)
32
- return wav
33
-
34
- def encode_ref(codec, wav):
35
- """Encode reference audio outside inference mode."""
36
- import torch
37
- # Exit inference mode temporarily for codec (it needs autograd for some ops)
38
- wav_clone = wav.clone().detach().to(DEVICE).requires_grad_(False)
39
- # Use torch.no_grad context instead of inference_mode
40
- with torch.no_grad():
41
- with torch.amp.autocast(device_type="cuda", dtype=DTYPE):
42
- enc = codec.encode(wav_clone.unsqueeze(0))
43
- tokens = enc[0] if isinstance(enc, tuple) else enc
44
- # Squeeze batch dim: (1, num_codebooks, T) -> (num_codebooks, T)
45
- if tokens.ndim == 3 and tokens.shape[0] == 1:
46
- tokens = tokens.squeeze(0)
47
- return tokens.detach()
48
-
49
- def generate_clone(model, codec, ref_tokens, ref_text, gen_text, out_path):
50
- """Generate voice clone using the Conversation API correctly."""
51
- conv = Conversation()
52
- conv.append(Message(role="user", parts=[
53
- VQPart(codes=ref_tokens.cpu()),
54
- TextPart(text=ref_text)
55
- ]))
56
- conv.append(Message(role="assistant", parts=[TextPart(text=gen_text)]))
57
-
58
- nc = model.config.num_codebooks
59
- tokenizer = model.tokenizer
60
- result = conv.encode_for_inference(tokenizer, nc)
61
- # encode_for_inference returns (prompt_tensor,) or prompt_tensor
62
- if isinstance(result, tuple):
63
- prompt = result[0].to(DEVICE)
64
- else:
65
- prompt = result.to(DEVICE)
66
-
67
- cd = 1 + nc
68
- am = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.bool, device=DEVICE)
69
- ap = torch.zeros(1, cd, prompt.shape[-1], dtype=torch.long, device=DEVICE)
70
-
71
- if not getattr(model, '_cd', False):
72
- model.setup_caches(1, model.config.max_seq_len, dtype=DTYPE)
73
- model._cd = True
74
-
75
- with torch.autocast(device_type="cuda", dtype=DTYPE):
76
- r = generate(model=model, prompt=prompt, max_new_tokens=1024,
77
- audio_masks=am, audio_parts=ap, temperature=0.7, top_p=0.7, top_k=30,
78
- decode_one_token=decode_one_token_ar)
79
-
80
- codes = r[0:1,:,:].unsqueeze(0)
81
- with torch.autocast(device_type="cuda", dtype=DTYPE):
82
- audio = codec.decode(codes.to(DEVICE))
83
-
84
- np_audio = audio.squeeze().cpu().float().numpy()
85
- sr = getattr(codec, 'sample_rate', 44100)
86
- sf.write(out_path, np_audio, sr)
87
- dur = len(np_audio) / sr
88
- print(f" Saved {out_path} ({dur:.1f}s)")
89
- return True
90
-
91
- # --- Models to test ---
92
- MODELS = {
93
- "baseline_bf16": {"source": "fishaudio/s2-pro", "quant": None},
94
- "fp8": {"source": "drbaph/s2-pro-fp8", "quant": None},
95
- }
96
 
97
  def main():
98
- # Load reference audio
99
- ref_path = "/app/reference/morgan_ref.wav"
100
- if not os.path.exists(ref_path):
101
- ref_path = "/tmp/reference/morgan_ref.wav"
102
- print(f"\n[1] Loading reference audio: {ref_path}")
103
- ref_wav = load_ref_audio(ref_path)
104
 
105
- results = {}
106
 
107
- for name, cfg in MODELS.items():
108
  print(f"\n{'='*60}")
109
- print(f" {name.upper()}")
110
  print(f"{'='*60}")
111
 
112
- model_id = cfg["source"]
113
- local_dir = f"{MODEL_DIR}/{name}"
114
-
115
- # Download if needed
116
  if not os.path.exists(f"{local_dir}/config.json"):
117
  print(f" Downloading {model_id}...")
118
  from huggingface_hub import snapshot_download
119
  snapshot_download(model_id, local_dir=local_dir, token=os.environ.get("HF_TOKEN"))
120
 
121
- # Load model
122
  print(f" Loading model...")
123
- model, _ = init_model(local_dir, DEVICE, DTYPE, compile=False)
124
  codec = load_codec_model(f"{local_dir}/codec.pth", DEVICE, DTYPE)
125
 
126
- # Encode reference
127
- ref_tokens = encode_ref(codec, ref_wav)
 
 
 
128
 
129
- # Generate
130
  out_path = f"{OUT}/fish_{name}_morgan_clone.wav"
 
131
  try:
132
- ok = generate_clone(model, codec, ref_tokens, REF_TEXT, GEN_TEXT, out_path)
133
- results[name] = {"ok": ok, "file": out_path}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  except Exception as e:
135
- print(f" FAILED: {e}")
136
  traceback.print_exc()
137
- results[name] = {"ok": False, "error": str(e)}
138
 
139
  del model, codec
140
  gc.collect()
141
  torch.cuda.empty_cache()
142
 
143
- # Also generate from GGUF models using s2.cpp if available
144
- # For now, just upload what we have
145
-
146
- # Summary
147
  print(f"\n{'='*60}")
148
- print(" RESULTS")
149
  print(f"{'='*60}")
150
- for name, r in results.items():
151
- status = "✅" if r["ok"] else "❌"
152
- print(f" {status} {name}: {r.get('file', r.get('error',''))}")
153
-
154
- # Upload to Hub
155
- print("\n[Final] Uploading samples to Hub...")
156
  try:
157
  from huggingface_hub import HfApi
158
  api = HfApi()
159
  repo = "Swagcrew/fish-speech-s2-quantized"
160
  for fn in os.listdir(OUT):
161
  if fn.endswith(".wav"):
162
- fpath = os.path.join(OUT, fn)
163
  api.upload_file(
164
- path_or_fileobj=fpath,
165
  path_in_repo=f"samples/{fn}",
166
  repo_id=repo,
167
  repo_type="model"
168
  )
169
  print(f" Uploaded samples/{fn}")
170
- print(f"\n All at https://huggingface.co/{repo}/tree/main/samples")
171
  except Exception as e:
172
  print(f" Upload error: {e}")
173
- traceback.print_exc()
174
 
175
  print("\nDONE!")
176
 
 
1
  #!/usr/bin/env python3
2
+ """Generate voice clone samples using fish-speech's generate_long API."""
3
+ import os, sys, json, time, gc, traceback
4
  import torch
5
+ import torchaudio
6
+ import soundfile as sf
7
 
8
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
10
+ sys.path.insert(0, "/app/fish-speech")
11
+
12
  DEVICE = "cuda"
13
  DTYPE = torch.bfloat16
14
  REF_TEXT = "Let me get this straight. You think that your client, one of the wealthiest most powerful men in the world, is secretly a vigilante who spends his nights beating criminals to a pulp with his bare hands. And your plan is to blackmail this person."
15
  GEN_TEXT = "Every man's life ends the same way. It is only the details of how he lived that distinguish one man from another."
16
  OUT = "/tmp/samples"
 
 
 
 
17
 
18
  os.makedirs(OUT, exist_ok=True)
 
19
 
20
+ from fish_speech.models.text2semantic.inference import (
21
+ init_model, load_codec_model, encode_audio, generate_long
22
+ )
23
 
24
+ MODELS = [
25
+ ("baseline_bf16", "fishaudio/s2-pro"),
26
+ ("fp8", "drbaph/s2-pro-fp8"),
27
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def main():
30
+ print(f"=== Fish Speech Voice Clone Sample Generator ===")
31
+ print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
 
 
 
 
32
 
33
+ ref_path = "/app/reference/morgan_ref.wav"
34
 
35
+ for name, model_id in MODELS:
36
  print(f"\n{'='*60}")
37
+ print(f" {name.upper()} ({model_id})")
38
  print(f"{'='*60}")
39
 
40
+ local_dir = f"/tmp/models/{name}"
 
 
 
41
  if not os.path.exists(f"{local_dir}/config.json"):
42
  print(f" Downloading {model_id}...")
43
  from huggingface_hub import snapshot_download
44
  snapshot_download(model_id, local_dir=local_dir, token=os.environ.get("HF_TOKEN"))
45
 
 
46
  print(f" Loading model...")
47
+ model, decode_fn = init_model(local_dir, DEVICE, DTYPE, compile=False)
48
  codec = load_codec_model(f"{local_dir}/codec.pth", DEVICE, DTYPE)
49
 
50
+ with torch.device(DEVICE):
51
+ model.setup_caches(max_batch_size=1, max_seq_len=model.config.max_seq_len, dtype=DTYPE)
52
+
53
+ print(f" Encoding reference audio...")
54
+ prompt_tokens = encode_audio(ref_path, codec, DEVICE).cpu()
55
 
56
+ print(f" Generating voice clone...")
57
  out_path = f"{OUT}/fish_{name}_morgan_clone.wav"
58
+
59
  try:
60
+ for response in generate_long(
61
+ model=model,
62
+ device=DEVICE,
63
+ decode_one_token=decode_fn,
64
+ text=GEN_TEXT,
65
+ max_new_tokens=1024,
66
+ top_p=0.7,
67
+ top_k=30,
68
+ temperature=0.7,
69
+ repetition_penalty=1.1,
70
+ compile=False,
71
+ iterative_prompt=False,
72
+ chunk_length=0,
73
+ prompt_text=REF_TEXT,
74
+ prompt_tokens=prompt_tokens,
75
+ ):
76
+ if response.action == "sample":
77
+ codes = response.codes
78
+ with torch.no_grad():
79
+ with torch.amp.autocast(device_type="cuda", dtype=DTYPE):
80
+ audio = codec.decode(codes.unsqueeze(0).to(DEVICE))
81
+ np_audio = audio.squeeze().cpu().float().numpy()
82
+ sr = getattr(codec, 'sample_rate', 44100)
83
+ sf.write(out_path, np_audio, sr)
84
+ dur = len(np_audio) / sr
85
+ print(f" ✅ Saved {out_path} ({dur:.1f}s)")
86
+
87
  except Exception as e:
88
+ print(f" FAILED: {e}")
89
  traceback.print_exc()
 
90
 
91
  del model, codec
92
  gc.collect()
93
  torch.cuda.empty_cache()
94
 
95
+ # Upload
 
 
 
96
  print(f"\n{'='*60}")
97
+ print(f" UPLOADING TO HUB")
98
  print(f"{'='*60}")
 
 
 
 
 
 
99
  try:
100
  from huggingface_hub import HfApi
101
  api = HfApi()
102
  repo = "Swagcrew/fish-speech-s2-quantized"
103
  for fn in os.listdir(OUT):
104
  if fn.endswith(".wav"):
 
105
  api.upload_file(
106
+ path_or_fileobj=os.path.join(OUT, fn),
107
  path_in_repo=f"samples/{fn}",
108
  repo_id=repo,
109
  repo_type="model"
110
  )
111
  print(f" Uploaded samples/{fn}")
112
+ print(f"\n https://huggingface.co/{repo}/tree/main/samples")
113
  except Exception as e:
114
  print(f" Upload error: {e}")
 
115
 
116
  print("\nDONE!")
117