Swagcrew committed on
Commit
60a491a
·
verified ·
1 Parent(s): 3a38d65

Upload gen_samples.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gen_samples.py +6 -32
gen_samples.py CHANGED
@@ -53,6 +53,7 @@ def main():
53
  "--no-iterative-prompt",
54
  "--chunk-length", "0",
55
  "--device", "cuda",
 
56
  ]
57
 
58
  print(f" Generating semantic tokens...")
@@ -63,38 +64,11 @@ def main():
63
  if result.stderr:
64
  print(f" CLI stderr (last 500): {result.stderr[-500:]}")
65
 
66
- # Find generated .pt files
67
- pt_files = [f for f in os.listdir(semantic_dir) if f.endswith('.pt')]
68
- if not pt_files:
69
- print(f" ❌ No .pt files generated")
70
- continue
71
-
72
- print(f" Generated {len(pt_files)} semantic files")
73
-
74
- # Step 2: Decode semantic tokens to audio using codec
75
- import torchaudio
76
- import soundfile as sf
77
- from fish_speech.models.text2semantic.inference import load_codec_model
78
-
79
- codec = load_codec_model(f"{local_dir}/codec.pth", "cuda", torch.bfloat16)
80
-
81
- for pt_file in pt_files:
82
- codes = torch.load(os.path.join(semantic_dir, pt_file), map_location="cuda", weights_only=True)
83
- print(f" Decoding {pt_file}, codes shape: {codes.shape}")
84
-
85
- with torch.no_grad():
86
- with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
87
- audio = codec.decode(codes.unsqueeze(0))
88
-
89
- np_audio = audio.squeeze().cpu().float().numpy()
90
- sr = getattr(codec, 'sample_rate', 44100)
91
- sf.write(out_path, np_audio, sr)
92
- dur = len(np_audio) / sr
93
- print(f" ✅ Saved {out_path} ({dur:.1f}s)")
94
-
95
- del codec
96
- gc.collect()
97
- torch.cuda.empty_cache()
98
 
99
  # Upload
100
  print(f"\n{'='*60}")
 
53
  "--no-iterative-prompt",
54
  "--chunk-length", "0",
55
  "--device", "cuda",
56
+ "--output", out_path,
57
  ]
58
 
59
  print(f" Generating semantic tokens...")
 
64
  if result.stderr:
65
  print(f" CLI stderr (last 500): {result.stderr[-500:]}")
66
 
67
+ if os.path.exists(out_path):
68
+ sz = os.path.getsize(out_path)
69
+ print(f" ✅ Saved {out_path} ({sz/1024:.0f}KB)")
70
+ else:
71
+ print(f" ❌ Output not found: {out_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Upload
74
  print(f"\n{'='*60}")