Spaces:
Sleeping
Sleeping
futurespyhi
commited on
Commit
·
6d4b73b
1
Parent(s):
4b92fef
1.modify split_lyrics regex 2.define raw_output outside loop
Browse files- YuE/inference/infer.py +4 -1
YuE/inference/infer.py
CHANGED
|
@@ -131,7 +131,7 @@ def encode_audio(codec_model, audio_prompt, device, target_bw=0.5):
|
|
| 131 |
return raw_codes
|
| 132 |
|
| 133 |
def split_lyrics(lyrics):
|
| 134 |
-
pattern = r"\[(
|
| 135 |
segments = re.findall(pattern, lyrics, re.DOTALL)
|
| 136 |
structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
|
| 137 |
return structured_lyrics
|
|
@@ -162,6 +162,7 @@ start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
|
|
| 162 |
end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
|
| 163 |
# Format text prompt
|
| 164 |
run_n_segments = min(args.run_n_segments+1, len(lyrics))
|
|
|
|
| 165 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference...")):
|
| 166 |
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
| 167 |
guidance_scale = 1.5 if i <=1 else 1.2
|
|
@@ -224,6 +225,8 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference
|
|
| 224 |
raw_output = output_seq
|
| 225 |
|
| 226 |
# save raw output and check sanity
|
|
|
|
|
|
|
| 227 |
ids = raw_output[0].cpu().numpy()
|
| 228 |
soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
|
| 229 |
eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
|
|
|
|
| 131 |
return raw_codes
|
| 132 |
|
| 133 |
def split_lyrics(lyrics):
|
| 134 |
+
pattern = r"\[([^]]+)\](.*?)(?=\[|\Z)"
|
| 135 |
segments = re.findall(pattern, lyrics, re.DOTALL)
|
| 136 |
structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
|
| 137 |
return structured_lyrics
|
|
|
|
| 162 |
end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
|
| 163 |
# Format text prompt
|
| 164 |
run_n_segments = min(args.run_n_segments+1, len(lyrics))
|
| 165 |
+
raw_output = None
|
| 166 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments], desc="Stage1 inference...")):
|
| 167 |
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
| 168 |
guidance_scale = 1.5 if i <=1 else 1.2
|
|
|
|
| 225 |
raw_output = output_seq
|
| 226 |
|
| 227 |
# save raw output and check sanity
|
| 228 |
+
if raw_output is None:
|
| 229 |
+
raise ValueError("No valid segments were processed. Check your lyrics format and run_n_segments parameter.")
|
| 230 |
ids = raw_output[0].cpu().numpy()
|
| 231 |
soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
|
| 232 |
eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
|