futurespyhi commited on
Commit
b81feb7
·
1 Parent(s): 0198d75

Fix syntax errors in infer.py

Browse files

- Remove incomplete try-except blocks that caused syntax errors
- Fix indentation issues in generation loop
- Clean up orphaned except clauses
- Restore proper code structure while keeping debug information

Files changed (1) hide show
  1. YuEGP/inference/infer.py +8 -20
YuEGP/inference/infer.py CHANGED
@@ -236,11 +236,10 @@ end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
236
  # Format text prompt
237
  run_n_segments = min(args.run_n_segments + 1, len(lyrics))
238
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
239
- try:
240
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
241
- guidance_scale = 1.5 if i <= 1 else 1.2
242
- if i == 0:
243
- continue
244
  if i == 1:
245
  if args.use_dual_tracks_prompt or args.use_audio_prompt:
246
  if args.use_dual_tracks_prompt:
@@ -299,21 +298,10 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
299
  if output_seq[0][-1].item() != mmtokenizer.eoa:
300
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
301
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
302
- if i > 1:
303
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
304
- else:
305
- raw_output = output_seq
306
- except Exception as e:
307
- print(f"❌ Error during generation for segment {i}: {e}")
308
- print(f"Section text: {section_text}")
309
- import traceback
310
- traceback.print_exc()
311
- # If this is the first successful generation, we still have something
312
- if i > 1 and 'raw_output' in locals():
313
- print("⚠️ Using partial output from previous segments")
314
- break
315
- else:
316
- raise RuntimeError(f"Generation failed at segment {i}: {e}")
317
 
318
  # save raw output and check sanity
319
  # Check if raw_output was defined (debug for generation issues)
 
236
  # Format text prompt
237
  run_n_segments = min(args.run_n_segments + 1, len(lyrics))
238
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
239
+ section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
240
+ guidance_scale = 1.5 if i <= 1 else 1.2
241
+ if i == 0:
242
+ continue
 
243
  if i == 1:
244
  if args.use_dual_tracks_prompt or args.use_audio_prompt:
245
  if args.use_dual_tracks_prompt:
 
298
  if output_seq[0][-1].item() != mmtokenizer.eoa:
299
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
300
  output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
301
+ if i > 1:
302
+ raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
303
+ else:
304
+ raw_output = output_seq
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  # save raw output and check sanity
307
  # Check if raw_output was defined (debug for generation issues)