futurespyhi commited on
Commit
a4c1ace
·
1 Parent(s): 57403ec

Add exception handling to YuEGP generation loop

Browse files

- Add try-catch around generation loop to capture actual errors
- Provide detailed error messages and stack traces for debugging
- Handle partial output recovery when possible
- Prevent undefined raw_output variable error

Files changed (1) hide show
  1. YuEGP/inference/infer.py +20 -8
YuEGP/inference/infer.py CHANGED
@@ -236,10 +236,11 @@ end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
236
  # Format text prompt
237
  run_n_segments = min(args.run_n_segments + 1, len(lyrics))
238
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
239
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
240
- guidance_scale = 1.5 if i <= 1 else 1.2
241
- if i == 0:
242
- continue
 
243
  if i == 1:
244
  if args.use_dual_tracks_prompt or args.use_audio_prompt:
245
  if args.use_dual_tracks_prompt:
@@ -298,10 +299,21 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
298
  if output_seq[0][-1].item() != mmtokenizer.eoa:
299
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
300
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
301
- if i > 1:
302
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
303
- else:
304
- raw_output = output_seq
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  # save raw output and check sanity
307
  ids = raw_output[0].cpu().numpy()
 
236
  # Format text prompt
237
  run_n_segments = min(args.run_n_segments + 1, len(lyrics))
238
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
239
+ try:
240
+ section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
241
+ guidance_scale = 1.5 if i <= 1 else 1.2
242
+ if i == 0:
243
+ continue
244
  if i == 1:
245
  if args.use_dual_tracks_prompt or args.use_audio_prompt:
246
  if args.use_dual_tracks_prompt:
 
299
  if output_seq[0][-1].item() != mmtokenizer.eoa:
300
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
301
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
302
+ if i > 1:
303
+ raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
304
+ else:
305
+ raw_output = output_seq
306
+ except Exception as e:
307
+ print(f"❌ Error during generation for segment {i}: {e}")
308
+ print(f"Section text: {section_text}")
309
+ import traceback
310
+ traceback.print_exc()
311
+ # If this is the first successful generation, we still have something
312
+ if i > 1 and 'raw_output' in locals():
313
+ print("⚠️ Using partial output from previous segments")
314
+ break
315
+ else:
316
+ raise RuntimeError(f"Generation failed at segment {i}: {e}")
317
 
318
  # save raw output and check sanity
319
  ids = raw_output[0].cpu().numpy()