futurespyhi committed on
Commit
6d4b73b
·
1 Parent(s): 4b92fef

1. Modify the split_lyrics regex. 2. Define raw_output outside the loop.

Browse files
Files changed (1) hide show
  1. YuE/inference/infer.py +4 -1
YuE/inference/infer.py CHANGED
@@ -131,7 +131,7 @@ def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
131
  return raw_codes
132
 
133
  def split_lyrics(lyrics):
134
- pattern = r"\[(\w+)\](.*?)(?=\[|\Z)"
135
  segments = re.findall(pattern, lyrics, re.DOTALL)
136
  structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
137
  return structured_lyrics
@@ -162,6 +162,7 @@ start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
162
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
163
  # Format text prompt
164
  run_n_segments = min(args.run_n_segments+1, len(lyrics))
 
165
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference...")):
166
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
167
  guidance_scale = 1.5 if i <=1 else 1.2
@@ -224,6 +225,8 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference
224
  raw_output = output_seq
225
 
226
  # save raw output and check sanity
 
 
227
  ids = raw_output[0].cpu().numpy()
228
  soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
229
  eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
 
131
  return raw_codes
132
 
133
  def split_lyrics(lyrics):
134
+ pattern = r"\[([^]]+)\](.*?)(?=\[|\Z)"
135
  segments = re.findall(pattern, lyrics, re.DOTALL)
136
  structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
137
  return structured_lyrics
 
162
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
163
  # Format text prompt
164
  run_n_segments = min(args.run_n_segments+1, len(lyrics))
165
+ raw_output = None
166
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference...")):
167
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
168
  guidance_scale = 1.5 if i <=1 else 1.2
 
225
  raw_output = output_seq
226
 
227
  # save raw output and check sanity
228
+ if raw_output is None:
229
+ raise ValueError("No valid segments were processed. Check your lyrics format and run_n_segments parameter.")
230
  ids = raw_output[0].cpu().numpy()
231
  soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
232
  eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()