Spaces:
Running
Running
futurespyhi
committed on
Commit
·
a4c1ace
1
Parent(s):
57403ec
Add exception handling to YuEGP generation loop
Browse files
- Add try/except around generation loop to capture actual errors
- Provide detailed error messages and stack traces for debugging
- Handle partial output recovery when possible
- Prevent undefined raw_output variable error
- YuEGP/inference/infer.py +20 -8
YuEGP/inference/infer.py
CHANGED
|
@@ -236,10 +236,11 @@ end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
|
|
| 236 |
# Format text prompt
|
| 237 |
run_n_segments = min(args.run_n_segments + 1, len(lyrics))
|
| 238 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
| 243 |
if i == 1:
|
| 244 |
if args.use_dual_tracks_prompt or args.use_audio_prompt:
|
| 245 |
if args.use_dual_tracks_prompt:
|
|
@@ -298,10 +299,21 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
|
| 298 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
| 299 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
| 300 |
output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
# save raw output and check sanity
|
| 307 |
ids = raw_output[0].cpu().numpy()
|
|
|
|
| 236 |
# Format text prompt
|
| 237 |
run_n_segments = min(args.run_n_segments + 1, len(lyrics))
|
| 238 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
| 239 |
+
try:
|
| 240 |
+
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
| 241 |
+
guidance_scale = 1.5 if i <= 1 else 1.2
|
| 242 |
+
if i == 0:
|
| 243 |
+
continue
|
| 244 |
if i == 1:
|
| 245 |
if args.use_dual_tracks_prompt or args.use_audio_prompt:
|
| 246 |
if args.use_dual_tracks_prompt:
|
|
|
|
| 299 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
| 300 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
| 301 |
output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
|
| 302 |
+
if i > 1:
|
| 303 |
+
raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
|
| 304 |
+
else:
|
| 305 |
+
raw_output = output_seq
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"❌ Error during generation for segment {i}: {e}")
|
| 308 |
+
print(f"Section text: {section_text}")
|
| 309 |
+
import traceback
|
| 310 |
+
traceback.print_exc()
|
| 311 |
+
# If this is the first successful generation, we still have something
|
| 312 |
+
if i > 1 and 'raw_output' in locals():
|
| 313 |
+
print("⚠️ Using partial output from previous segments")
|
| 314 |
+
break
|
| 315 |
+
else:
|
| 316 |
+
raise RuntimeError(f"Generation failed at segment {i}: {e}")
|
| 317 |
|
| 318 |
# save raw output and check sanity
|
| 319 |
ids = raw_output[0].cpu().numpy()
|