Convert to MCP Client
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +1 -1
- app.py +224 -329
- audiocraft/__init__.py +0 -10
- audiocraft/data/__init__.py +0 -8
- audiocraft/data/audio.py +0 -422
- audiocraft/data/audio_dataset.py +0 -587
- audiocraft/data/audio_utils.py +0 -296
- audiocraft/data/info_audio_dataset.py +0 -110
- audiocraft/data/zip.py +0 -76
- audiocraft/environment.py +0 -176
- audiocraft/models/__init__.py +0 -21
- audiocraft/models/builders.py +0 -351
- audiocraft/models/encodec.py +0 -506
- audiocraft/models/flow_matching.py +0 -516
- audiocraft/models/genmodel.py +0 -273
- audiocraft/models/lm.py +0 -588
- audiocraft/models/lm_magnet.py +0 -500
- audiocraft/models/loaders.py +0 -291
- audiocraft/models/magnet.py +0 -88
- audiocraft/models/musicgen.py +0 -566
- audiocraft/models/unet.py +0 -214
- audiocraft/modules/__init__.py +0 -21
- audiocraft/modules/activations.py +0 -96
- audiocraft/modules/chroma.py +0 -66
- audiocraft/modules/codebooks_patterns.py +0 -548
- audiocraft/modules/conditioners.py +0 -1763
- audiocraft/modules/conv.py +0 -245
- audiocraft/modules/diffusion_schedule.py +0 -272
- audiocraft/modules/jasco_conditioners.py +0 -300
- audiocraft/modules/lstm.py +0 -25
- audiocraft/modules/rope.py +0 -125
- audiocraft/modules/seanet.py +0 -258
- audiocraft/modules/streaming.py +0 -135
- audiocraft/modules/transformer.py +0 -755
- audiocraft/modules/unet_transformer.py +0 -67
- audiocraft/py.typed +0 -0
- audiocraft/quantization/__init__.py +0 -9
- audiocraft/quantization/base.py +0 -107
- audiocraft/quantization/core_vq.py +0 -405
- audiocraft/quantization/vq.py +0 -116
- audiocraft/utils/__init__.py +0 -5
- audiocraft/utils/autocast.py +0 -40
- audiocraft/utils/cache.py +0 -324
- audiocraft/utils/cluster.py +0 -75
- audiocraft/utils/export.py +0 -79
- audiocraft/utils/export_legacy.py +0 -56
- audiocraft/utils/extend.py +0 -440
- audiocraft/utils/notebook.py +0 -32
- audiocraft/utils/utils.py +0 -328
- modules/constants.py +63 -0
README.md
CHANGED
|
@@ -10,7 +10,7 @@ app_file: app.py
|
|
| 10 |
pinned: true
|
| 11 |
license: creativeml-openrail-m
|
| 12 |
tags:
|
| 13 |
-
agent-demo-track
|
| 14 |
- musicgen
|
| 15 |
- unlimited
|
| 16 |
- user history
|
|
|
|
| 10 |
pinned: true
|
| 11 |
license: creativeml-openrail-m
|
| 12 |
tags:
|
| 13 |
+
- agent-demo-track
|
| 14 |
- musicgen
|
| 15 |
- unlimited
|
| 16 |
- user history
|
app.py
CHANGED
|
@@ -19,11 +19,6 @@ import typing as tp
|
|
| 19 |
import warnings
|
| 20 |
import gc
|
| 21 |
from tqdm import tqdm
|
| 22 |
-
from audiocraft.models import MusicGen
|
| 23 |
-
from audiocraft.data.audio import audio_write
|
| 24 |
-
from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect
|
| 25 |
-
from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, INTERRUPTING
|
| 26 |
-
from audiocraft.utils import utils
|
| 27 |
import numpy as np
|
| 28 |
import random
|
| 29 |
import shutil
|
|
@@ -35,9 +30,14 @@ from modules.version_info import versions_html, commit_hash, get_xformers_versio
|
|
| 35 |
from modules.gradio import *
|
| 36 |
from modules.file_utils import get_file_parts, get_filename_from_filepath, convert_title_to_filename, get_unique_file_path, delete_file
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
MODEL = None
|
| 39 |
MODELS = None
|
| 40 |
-
IS_SHARED_SPACE = "
|
| 41 |
INTERRUPTED = False
|
| 42 |
UNLOAD_MODEL = False
|
| 43 |
MOVE_TO_CPU = False
|
|
@@ -239,343 +239,238 @@ def load_melody_filepath(melody_filepath, title, assigned_model, topp, temperatu
|
|
| 239 |
|
| 240 |
return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
|
| 241 |
|
| 242 |
-
def predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
"""
|
| 244 |
-
Generate music and video
|
| 245 |
-
|
| 246 |
-
Args:
|
| 247 |
-
model (str): Model name to use for generation.
|
| 248 |
-
text (str): Prompt describing the music.
|
| 249 |
-
melody_filepath (str): Path to melody conditioning file. default to None.
|
| 250 |
-
duration (int): Total duration in seconds.
|
| 251 |
-
dimension (int): Audio stacking/concatenation dimension.
|
| 252 |
-
topk (int): Top-k sampling value.
|
| 253 |
-
topp (float): Top-p sampling value.
|
| 254 |
-
temperature (float): Sampling temperature.
|
| 255 |
-
cfg_coef (float): Classifier-free guidance coefficient.
|
| 256 |
-
background (str): Path to background image. default to "./assets/background.png".
|
| 257 |
-
title (str): Song title.
|
| 258 |
-
settings_font (str): Path to font file.
|
| 259 |
-
settings_font_color (str): Font color for settings text.
|
| 260 |
-
seed (int): Random seed.
|
| 261 |
-
overlap (int, optional): Segment overlap in seconds.
|
| 262 |
-
prompt_index (int, optional): Melody segment index.
|
| 263 |
-
include_title (bool, optional): Whether to add title to video.
|
| 264 |
-
include_settings (bool, optional): Whether to add settings to video.
|
| 265 |
-
harmony_only (bool, optional): Whether to use harmony only.
|
| 266 |
-
profile (gr.OAuthProfile): User profile.
|
| 267 |
-
segment_length (int, optional): Segment length in seconds.
|
| 268 |
-
settings_font_size (int, optional): Font size for settings text.
|
| 269 |
-
settings_animate_waveform (bool, optional): Animate waveform in video.
|
| 270 |
-
video_orientation (str, optional): Video orientation.
|
| 271 |
-
excerpt_duration (float, optional): Excerpt duration for style conditioning.
|
| 272 |
-
progress (gr.Progress, optional): Gradio progress tracker.
|
| 273 |
-
|
| 274 |
-
Returns:
|
| 275 |
-
tuple: (waveform_video_path, wave_file_path, seed_used)
|
| 276 |
"""
|
| 277 |
-
global
|
| 278 |
-
output_segments = None
|
| 279 |
-
melody_name = "Not Used"
|
| 280 |
-
melody_extension = "Not Used"
|
| 281 |
-
melody = None
|
| 282 |
-
if melody_filepath in ["None", ""]:
|
| 283 |
-
melody_filepath = None
|
| 284 |
-
|
| 285 |
-
if background in ["None", ""]:
|
| 286 |
-
background = "./assets/background.png"
|
| 287 |
-
|
| 288 |
-
if melody_filepath:
|
| 289 |
-
melody_name, melody_extension = get_filename_from_filepath(melody_filepath)
|
| 290 |
-
melody = get_melody(melody_filepath)
|
| 291 |
-
|
| 292 |
INTERRUPTED = False
|
| 293 |
INTERRUPTING = False
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
try:
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
segment_duration = duration + overlap
|
| 339 |
-
else:
|
| 340 |
-
segment_duration = MODEL.lm.cfg.dataset.segment_duration
|
| 341 |
-
if (segment_length + overlap) < segment_duration:
|
| 342 |
-
segment_duration = segment_length + overlap
|
| 343 |
-
# implement seed
|
| 344 |
-
if seed < 0:
|
| 345 |
-
seed = random.randint(0, 0xffff_ffff_ffff)
|
| 346 |
-
torch.manual_seed(seed)
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap}')
|
| 350 |
-
if ("style" in model) and melody:
|
| 351 |
-
# style and text-to-music
|
| 352 |
-
MODEL.set_generation_params(
|
| 353 |
-
use_sampling=True,
|
| 354 |
-
top_k=topk,
|
| 355 |
-
top_p=topp,
|
| 356 |
-
temperature=temperature,
|
| 357 |
-
cfg_coef=cfg_coef,
|
| 358 |
-
duration=segment_duration,
|
| 359 |
-
two_step_cfg=False,
|
| 360 |
-
cfg_coef_beta=5, # double CFG is only useful for text-and-style conditioning
|
| 361 |
-
)
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
excerpt_length=excerpt_duration, # the length in seconds that is taken by the model in the provided excerpt, can be
|
| 369 |
-
# between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
| 370 |
-
)
|
| 371 |
-
else:
|
| 372 |
-
MODEL.set_generation_params(
|
| 373 |
-
use_sampling=True,
|
| 374 |
-
top_k=topk,
|
| 375 |
-
top_p=topp,
|
| 376 |
-
temperature=temperature,
|
| 377 |
-
cfg_coef=cfg_coef,
|
| 378 |
-
duration=segment_duration,
|
| 379 |
-
two_step_cfg=False,
|
| 380 |
-
extend_stride=2,
|
| 381 |
-
rep_penalty=0.5,
|
| 382 |
-
cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning
|
| 383 |
)
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
# return excess duration, load next model and continue in loop structure building up output_segments
|
| 389 |
-
if duration > MODEL.duration:
|
| 390 |
-
output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.duration, prompt_index, harmony_only, excerpt_duration, progress=gr.Progress(track_tqdm=True))
|
| 391 |
-
else:
|
| 392 |
-
# pure original code
|
| 393 |
-
sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
|
| 394 |
-
print(melody.shape)
|
| 395 |
-
if melody.dim() == 2:
|
| 396 |
-
melody = melody[None]
|
| 397 |
-
melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
|
| 398 |
-
output = MODEL.generate_with_chroma(
|
| 399 |
-
descriptions=[text],
|
| 400 |
-
melody_wavs=melody,
|
| 401 |
-
melody_sample_rate=sr,
|
| 402 |
-
progress=False, progress_callback=gr.Progress(track_tqdm=True)
|
| 403 |
-
)
|
| 404 |
-
# All output_segments are populated, so we can break the loop or set duration to 0
|
| 405 |
-
break
|
| 406 |
-
else:
|
| 407 |
-
#output = MODEL.generate(descriptions=[text], progress=False)
|
| 408 |
-
if not output_segments:
|
| 409 |
-
next_segment = MODEL.generate(descriptions=[text], progress=False, progress_callback=gr.Progress(track_tqdm=True))
|
| 410 |
-
duration -= segment_duration
|
| 411 |
-
else:
|
| 412 |
-
last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
|
| 413 |
-
next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=False, progress_callback=gr.Progress(track_tqdm=True))
|
| 414 |
-
duration -= segment_duration - overlap
|
| 415 |
-
if next_segment != None:
|
| 416 |
-
output_segments.append(next_segment)
|
| 417 |
-
except Exception as e:
|
| 418 |
-
print(f"Error generating audio: {e}")
|
| 419 |
-
gr.Error(f"Error generating audio: {e}")
|
| 420 |
-
return None, None, seed
|
| 421 |
-
|
| 422 |
-
if INTERRUPTING:
|
| 423 |
-
INTERRUPTED = True
|
| 424 |
-
INTERRUPTING = False
|
| 425 |
-
print("Function execution interrupted!")
|
| 426 |
-
raise gr.Error("Interrupted.")
|
| 427 |
-
|
| 428 |
-
print(f"\nOutput segments: {len(output_segments)}\n")
|
| 429 |
-
if output_segments:
|
| 430 |
-
try:
|
| 431 |
-
# Combine the output segments into one long audio file or stack tracks
|
| 432 |
-
#output_segments = [segment.detach().cpu().float()[0] for segment in output_segments]
|
| 433 |
-
#output = torch.cat(output_segments, dim=dimension)
|
| 434 |
-
|
| 435 |
-
output = output_segments[0]
|
| 436 |
-
for i in range(1, len(output_segments)):
|
| 437 |
-
if overlap > 0:
|
| 438 |
-
overlap_samples = overlap * MODEL.sample_rate
|
| 439 |
-
#stack tracks and fade out/in
|
| 440 |
-
overlapping_output_fadeout = output[:, :, -overlap_samples:]
|
| 441 |
-
#overlapping_output_fadeout = apply_fade(overlapping_output_fadeout,sample_rate=MODEL.sample_rate,duration=overlap,out=True,start=True, curve_end=0.0, current_device=MODEL.device)
|
| 442 |
-
overlapping_output_fadeout = apply_tafade(overlapping_output_fadeout,sample_rate=MODEL.sample_rate,duration=overlap,out=True,start=True,shape="linear")
|
| 443 |
-
|
| 444 |
-
overlapping_output_fadein = output_segments[i][:, :, :overlap_samples]
|
| 445 |
-
#overlapping_output_fadein = apply_fade(overlapping_output_fadein,sample_rate=MODEL.sample_rate,duration=overlap,out=False,start=False, curve_start=0.0, current_device=MODEL.device)
|
| 446 |
-
overlapping_output_fadein = apply_tafade(overlapping_output_fadein,sample_rate=MODEL.sample_rate,duration=overlap,out=False,start=False, shape="linear")
|
| 447 |
-
|
| 448 |
-
overlapping_output = torch.cat([overlapping_output_fadeout[:, :, :-(overlap_samples // 2)], overlapping_output_fadein],dim=2)
|
| 449 |
-
###overlapping_output, overlap_sample_rate = apply_splice_effect(overlapping_output_fadeout, MODEL.sample_rate, overlapping_output_fadein, MODEL.sample_rate, overlap)
|
| 450 |
-
print(f" overlap size Fade:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
|
| 451 |
-
##overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=1) #stack tracks
|
| 452 |
-
##print(f" overlap size stack:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
|
| 453 |
-
#overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=2) #stack tracks
|
| 454 |
-
#print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
|
| 455 |
-
output = torch.cat([output[:, :, :-overlap_samples], overlapping_output, output_segments[i][:, :, overlap_samples:]], dim=dimension)
|
| 456 |
-
else:
|
| 457 |
-
output = torch.cat([output, output_segments[i]], dim=dimension)
|
| 458 |
-
output = output.detach().cpu().float()[0]
|
| 459 |
-
except Exception as e:
|
| 460 |
-
print(f"Error combining segments: {e}. Using the first segment only.")
|
| 461 |
-
output = output_segments[0].detach().cpu().float()[0]
|
| 462 |
-
else:
|
| 463 |
-
if (output is None) or (output.dim() == 0):
|
| 464 |
-
return None, None, seed
|
| 465 |
-
else:
|
| 466 |
-
output = output.detach().cpu().float()[0]
|
| 467 |
-
|
| 468 |
-
video_width, video_height = 768, 512
|
| 469 |
-
if video_orientation == "Portait":
|
| 470 |
-
video_width, video_height = 512, 768
|
| 471 |
-
|
| 472 |
-
title_file_name = convert_title_to_filename(title)
|
| 473 |
-
with NamedTemporaryFile("wb", suffix=".wav", delete=False, prefix=title_file_name) as file:
|
| 474 |
-
video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}"
|
| 475 |
-
if include_settings or include_title:
|
| 476 |
-
background = add_settings_to_image(title if include_title else "",video_description if include_settings else "",width=video_width, height=video_height, background_path=background,font=settings_font,font_color=settings_font_color, font_size=settings_font_size)
|
| 477 |
-
audio_write(
|
| 478 |
-
file.name, output, MODEL.sample_rate, strategy="loudness",
|
| 479 |
-
loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2)
|
| 480 |
-
waveform_video_path = get_waveform(file.name, bg_image=background, bar_count=45, name=title_file_name, animate=settings_animate_waveform, progress=gr.Progress(track_tqdm=True))
|
| 481 |
-
# Remove the extension from file.name
|
| 482 |
-
file_name_without_extension = os.path.splitext(file.name)[0]
|
| 483 |
-
# Get the directory, filename, name, extension, and new extension of the waveform video path
|
| 484 |
-
video_dir, video_name, video_name, video_ext, video_new_ext = get_file_parts(waveform_video_path)
|
| 485 |
-
|
| 486 |
-
new_video_path = get_unique_file_path(video_dir, title_file_name, video_new_ext)
|
| 487 |
-
|
| 488 |
-
mp4 = MP4(waveform_video_path)
|
| 489 |
-
mp4["©nam"] = title_file_name # Title tag
|
| 490 |
-
mp4["desc"] = f"{text}\n Duration: {str(initial_duration)}" # Description tag
|
| 491 |
-
|
| 492 |
-
commit = commit_hash()
|
| 493 |
-
metadata = {
|
| 494 |
-
"Title": title,
|
| 495 |
-
"Year": time.strftime("%Y"),
|
| 496 |
-
"prompt": text,
|
| 497 |
-
"negative_prompt": "",
|
| 498 |
-
"Seed": seed,
|
| 499 |
-
"steps": 1,
|
| 500 |
-
"wdth": video_width,
|
| 501 |
-
"hght": video_height,
|
| 502 |
-
"Dimension": dimension,
|
| 503 |
-
"Top-k": topk,
|
| 504 |
-
"Top-p": topp,
|
| 505 |
-
"Randomness": temperature,
|
| 506 |
-
"cfg": cfg_coef,
|
| 507 |
-
"overlap": overlap,
|
| 508 |
-
"Melody Condition": melody_name,
|
| 509 |
-
"Sample Segment": prompt_index,
|
| 510 |
-
"Duration": initial_duration,
|
| 511 |
-
"Audio": file.name,
|
| 512 |
-
"font": settings_font,
|
| 513 |
-
"font_color": settings_font_color,
|
| 514 |
-
"font_size": settings_font_size,
|
| 515 |
-
"harmony_only": harmony_only,
|
| 516 |
-
"background": background,
|
| 517 |
-
"include_title": include_title,
|
| 518 |
-
"include_settings": include_settings,
|
| 519 |
-
"profile": "Satoshi Nakamoto" if profile.value is None else profile.value.username,
|
| 520 |
-
"commit": commit_hash(),
|
| 521 |
-
"tag": git_tag(),
|
| 522 |
-
"version": gr.__version__,
|
| 523 |
-
"model_version": MODEL.version,
|
| 524 |
-
"model_name": MODEL.name,
|
| 525 |
-
"model_description": f"{MODEL.audio_channels} channels, {MODEL.sample_rate} Hz",
|
| 526 |
-
"melody_name": melody_name if melody_name else "",
|
| 527 |
-
"melody_extension": melody_extension if melody_extension else "",
|
| 528 |
-
"hostname": "https://huggingface.co/spaces/Surn/UnlimitedMusicGen",
|
| 529 |
-
"version": f"https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{'huggingface' if commit == '<none>' else commit}",
|
| 530 |
-
"python": sys.version,
|
| 531 |
-
"torch": getattr(torch, '__long_version__', torch.__version__),
|
| 532 |
-
"xformers": get_xformers_version(),
|
| 533 |
-
"gradio": gr.__version__,
|
| 534 |
-
"huggingface_space": os.environ.get('SPACE_ID', ''),
|
| 535 |
-
"CUDA": f"{'CUDA is available. device: ' + torch.cuda.get_device_name(0) + ' version: ' + torch.version.cuda if torch.cuda.is_available() else 'CUDA is not available.'}",
|
| 536 |
-
}
|
| 537 |
-
# Add additional metadata from the metadata dictionary (if it exists)
|
| 538 |
-
for key, value in metadata.items():
|
| 539 |
-
mp4[key] = str(value) # Convert values to strings as required by mutagen
|
| 540 |
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
image=background,
|
| 554 |
-
audio=file.name,
|
| 555 |
-
video=waveform_video_path,
|
| 556 |
-
label=title,
|
| 557 |
-
metadata=metadata,
|
| 558 |
-
progress=gr.Progress(track_tqdm=True)
|
| 559 |
-
)
|
| 560 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
MODEL = None
|
| 566 |
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
#gc.collect()
|
| 572 |
-
|
| 573 |
-
# Synchronize CUDA streams
|
| 574 |
-
torch.cuda.synchronize()
|
| 575 |
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
|
| 580 |
gr.set_static_paths(paths=["fonts/","assets/","images/"])
|
| 581 |
def ui(**kwargs):
|
|
|
|
| 19 |
import warnings
|
| 20 |
import gc
|
| 21 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
import numpy as np
|
| 23 |
import random
|
| 24 |
import shutil
|
|
|
|
| 30 |
from modules.gradio import *
|
| 31 |
from modules.file_utils import get_file_parts, get_filename_from_filepath, convert_title_to_filename, get_unique_file_path, delete_file
|
| 32 |
|
| 33 |
+
# Added for MCP call
|
| 34 |
+
from smolagents.mcp_client import MCPClient
|
| 35 |
+
from modules.storage import upload_files_to_repo # Added import
|
| 36 |
+
from modules.constants import HF_REPO_ID # Added import
|
| 37 |
+
|
| 38 |
MODEL = None
|
| 39 |
MODELS = None
|
| 40 |
+
IS_SHARED_SPACE = "Agents-MCP-Hackathon/UnlimitedMusicGen" in os.environ.get('SPACE_ID', '')
|
| 41 |
INTERRUPTED = False
|
| 42 |
UNLOAD_MODEL = False
|
| 43 |
MOVE_TO_CPU = False
|
|
|
|
| 239 |
|
| 240 |
return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
|
| 241 |
|
| 242 |
+
def predict(
|
| 243 |
+
model_name_arg, # Renamed from 'model'
|
| 244 |
+
text_arg,
|
| 245 |
+
melody_filepath_arg,
|
| 246 |
+
duration_arg,
|
| 247 |
+
dimension_arg,
|
| 248 |
+
topk_arg,
|
| 249 |
+
topp_arg,
|
| 250 |
+
temperature_arg,
|
| 251 |
+
cfg_coef_arg,
|
| 252 |
+
background_image_arg, # Renamed from 'background'
|
| 253 |
+
title_arg,
|
| 254 |
+
settings_font_path_arg, # Renamed from 'settings_font'
|
| 255 |
+
settings_font_color_arg,
|
| 256 |
+
seed_arg,
|
| 257 |
+
overlap_arg=1,
|
| 258 |
+
prompt_index_arg=0,
|
| 259 |
+
include_title_arg=True,
|
| 260 |
+
include_settings_arg=True,
|
| 261 |
+
harmony_only_arg=False,
|
| 262 |
+
profile_arg: tp.Optional[gr.OAuthProfile] = None, # Type hint for clarity, Gradio passes OAuthProfile or None
|
| 263 |
+
segment_length_arg=30,
|
| 264 |
+
settings_font_size_arg=28,
|
| 265 |
+
settings_animate_waveform_arg=False,
|
| 266 |
+
video_orientation_arg="Landscape",
|
| 267 |
+
excerpt_duration_arg=3.5,
|
| 268 |
+
progress_arg=gr.Progress(track_tqdm=True) # Renamed from 'progress', Gradio handles this
|
| 269 |
+
):
|
| 270 |
"""
|
| 271 |
+
Generate music and video by calling a remote MCP endpoint tool.
|
| 272 |
+
This function replaces the original local model inference.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
"""
|
| 274 |
+
global INTERRUPTED, INTERRUPTING # Retained, though effect on remote job is indirect
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
INTERRUPTED = False
|
| 276 |
INTERRUPTING = False
|
| 277 |
+
|
| 278 |
+
# Helper to get value if it's a Gradio State object
|
| 279 |
+
def get_value_if_state(arg):
|
| 280 |
+
if hasattr(arg, 'value') and arg.value is not None:
|
| 281 |
+
return arg.value
|
| 282 |
+
return arg
|
| 283 |
+
|
| 284 |
+
model_name_arg = get_value_if_state(model_name_arg)
|
| 285 |
+
text_arg = get_value_if_state(text_arg)
|
| 286 |
+
melody_filepath_arg = get_value_if_state(melody_filepath_arg)
|
| 287 |
+
duration_arg = get_value_if_state(duration_arg)
|
| 288 |
+
dimension_arg = get_value_if_state(dimension_arg)
|
| 289 |
+
topk_arg = get_value_if_state(topk_arg)
|
| 290 |
+
topp_arg = get_value_if_state(topp_arg)
|
| 291 |
+
temperature_arg = get_value_if_state(temperature_arg)
|
| 292 |
+
cfg_coef_arg = get_value_if_state(cfg_coef_arg)
|
| 293 |
+
background_image_arg = get_value_if_state(background_image_arg)
|
| 294 |
+
title_arg = get_value_if_state(title_arg)
|
| 295 |
+
settings_font_path_arg = get_value_if_state(settings_font_path_arg)
|
| 296 |
+
settings_font_color_arg = get_value_if_state(settings_font_color_arg)
|
| 297 |
+
seed_arg = get_value_if_state(seed_arg)
|
| 298 |
+
overlap_arg = get_value_if_state(overlap_arg)
|
| 299 |
+
prompt_index_arg = get_value_if_state(prompt_index_arg)
|
| 300 |
+
include_title_arg = get_value_if_state(include_title_arg)
|
| 301 |
+
include_settings_arg = get_value_if_state(include_settings_arg)
|
| 302 |
+
# harmony_only_arg is handled specifically below
|
| 303 |
+
# profile_arg is handled specifically below
|
| 304 |
+
segment_length_arg = get_value_if_state(segment_length_arg)
|
| 305 |
+
settings_font_size_arg = get_value_if_state(settings_font_size_arg)
|
| 306 |
+
settings_animate_waveform_arg = get_value_if_state(settings_animate_waveform_arg)
|
| 307 |
+
video_orientation_arg = get_value_if_state(video_orientation_arg)
|
| 308 |
+
excerpt_duration_arg = get_value_if_state(excerpt_duration_arg)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
print("Initiating MCP call to https://surn-unlimitedmusicgen.hf.space/gradio_api/mcp/sse tool UnlimitedMusicGen_predict")
|
| 312 |
+
|
| 313 |
+
mcp_client = None
|
| 314 |
+
melody_file_url = None # Changed from melody_file_obj
|
| 315 |
+
background_image_url = None # Changed from background_reference
|
| 316 |
|
| 317 |
try:
|
| 318 |
+
# Upload files to Hugging Face Hub and get URLs
|
| 319 |
+
files_to_upload = []
|
| 320 |
+
if melody_filepath_arg and melody_filepath_arg not in ["None", ""]:
|
| 321 |
+
files_to_upload.append(melody_filepath_arg)
|
| 322 |
+
if background_image_arg and background_image_arg not in ["None", ""] and not background_image_arg.startswith("http"): # only upload if not already a URL
|
| 323 |
+
files_to_upload.append(background_image_arg)
|
| 324 |
+
|
| 325 |
+
uploaded_file_urls = {}
|
| 326 |
+
if files_to_upload:
|
| 327 |
+
# Use a unique folder name for each upload session, e.g., based on timestamp or a random string
|
| 328 |
+
# For simplicity, using a fixed folder name here, but consider making it unique.
|
| 329 |
+
# The username from profile_arg could be used to create a user-specific folder.
|
| 330 |
+
profile_username_for_folder = "default_user"
|
| 331 |
+
if profile_arg:
|
| 332 |
+
# Check if profile_arg is a Gradio State object holding an OAuthProfile or string
|
| 333 |
+
actual_profile_data = profile_arg
|
| 334 |
+
if hasattr(profile_arg, 'value') and profile_arg.value is not None: # Handles gr.State wrapping OAuthProfile or string
|
| 335 |
+
actual_profile_data = profile_arg.value
|
| 336 |
+
|
| 337 |
+
if hasattr(actual_profile_data, 'username') and actual_profile_data.username: # OAuthProfile object
|
| 338 |
+
profile_username_for_folder = actual_profile_data.username
|
| 339 |
+
elif isinstance(actual_profile_data, str) and actual_profile_data: # String username
|
| 340 |
+
profile_username_for_folder = actual_profile_data
|
| 341 |
+
|
| 342 |
+
folder_name = f"user_uploads/{profile_username_for_folder}/{time.strftime('%Y%m%d%H%M%S')}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
+
upload_results = upload_files_to_repo(
|
| 345 |
+
files=files_to_upload,
|
| 346 |
+
repo_id=HF_REPO_ID, # Make sure HF_REPO_ID is defined in constants
|
| 347 |
+
folder_name=folder_name,
|
| 348 |
+
create_permalink=False # We need individual links
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
)
|
| 350 |
+
print(f"Upload results: {upload_results}")
|
| 351 |
+
|
| 352 |
+
if isinstance(upload_results, list):
|
| 353 |
+
for i, file_path in enumerate(files_to_upload):
|
| 354 |
+
original_filename = os.path.basename(file_path)
|
| 355 |
+
# Find the corresponding URL from upload_results
|
| 356 |
+
# The upload_results list contains tuples of (response, link)
|
| 357 |
+
# We need to match the uploaded file with its original path to assign the correct URL
|
| 358 |
+
# Assuming the order is preserved or filenames in links are reliable
|
| 359 |
+
for _, link in upload_results:
|
| 360 |
+
if original_filename in link:
|
| 361 |
+
uploaded_file_urls[file_path] = link
|
| 362 |
+
break
|
| 363 |
+
else: # Handle dict case or errors if necessary, though create_permalink=False should yield a list
|
| 364 |
+
print(f"Warning: Expected a list from upload_files_to_repo, got {type(upload_results)}")
|
| 365 |
+
|
| 366 |
+
if melody_filepath_arg and melody_filepath_arg in uploaded_file_urls:
|
| 367 |
+
melody_file_url = uploaded_file_urls[melody_filepath_arg]
|
| 368 |
+
print(f"Melody file uploaded to: {melody_file_url}")
|
| 369 |
+
elif melody_filepath_arg and melody_filepath_arg not in ["None", ""]: # File was provided but not uploaded (e.g. error)
|
| 370 |
+
print(f"Warning: Melody file {melody_filepath_arg} was provided but not successfully uploaded or URL not found.")
|
| 371 |
+
|
| 372 |
+
if background_image_arg and background_image_arg in uploaded_file_urls:
|
| 373 |
+
background_image_url = uploaded_file_urls[background_image_arg]
|
| 374 |
+
print(f"Background image uploaded to: {background_image_url}")
|
| 375 |
+
elif background_image_arg and background_image_arg.startswith("http"):
|
| 376 |
+
background_image_url = background_image_arg # It's already a URL
|
| 377 |
+
print(f"Using existing background image URL: {background_image_url}")
|
| 378 |
+
elif background_image_arg and background_image_arg not in ["None", ""]: # File was provided but not uploaded
|
| 379 |
+
print(f"Warning: Background image {background_image_arg} was provided but not successfully uploaded or URL not found.")
|
| 380 |
+
|
| 381 |
+
mcp_client = MCPClient({"url": "https://surn-unlimitedmusicgen.hf.space/gradio_api/mcp/sse"})
|
| 382 |
+
tools = mcp_client.get_tools()
|
| 383 |
+
|
| 384 |
+
predict_tool = next((t for t in tools if t.name == "UnlimitedMusicGen_predict"), None)
|
| 385 |
|
| 386 |
+
if not predict_tool:
|
| 387 |
+
raise gr.Error("MCP tool 'UnlimitedMusicGen_predict' not found on the server.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
+
profile_username_to_send = "Satoshi Nakamoto"
|
| 390 |
+
if profile_arg:
|
| 391 |
+
actual_profile_data = profile_arg
|
| 392 |
+
# Unwrap if it's a gr.State object
|
| 393 |
+
if hasattr(profile_arg, 'value') and profile_arg.value is not None:
|
| 394 |
+
actual_profile_data = profile_arg.value
|
| 395 |
+
|
| 396 |
+
# Now actual_profile_data is either an OAuthProfile or a string username
|
| 397 |
+
if hasattr(actual_profile_data, 'username') and actual_profile_data.username: # OAuthProfile
|
| 398 |
+
profile_username_to_send = actual_profile_data.username
|
| 399 |
+
elif isinstance(actual_profile_data, str) and actual_profile_data: # string username
|
| 400 |
+
profile_username_to_send = actual_profile_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
+
actual_harmony_only = False
|
| 403 |
+
if isinstance(harmony_only_arg, str):
|
| 404 |
+
actual_harmony_only = harmony_only_arg.lower() == "yes"
|
| 405 |
+
elif isinstance(harmony_only_arg, bool):
|
| 406 |
+
actual_harmony_only = harmony_only_arg
|
| 407 |
+
|
| 408 |
+
tool_args = {
|
| 409 |
+
"model_name_arg": model_name_arg,
|
| 410 |
+
"text_arg": text_arg,
|
| 411 |
+
"melody_filepath_arg": melody_file_url, # Pass URL instead of file object
|
| 412 |
+
"duration_arg": duration_arg,
|
| 413 |
+
"dimension_arg": dimension_arg,
|
| 414 |
+
"topk_arg": topk_arg,
|
| 415 |
+
"topp_arg": topp_arg,
|
| 416 |
+
"temperature_arg": temperature_arg,
|
| 417 |
+
"cfg_coef_arg": cfg_coef_arg,
|
| 418 |
+
"background_image_arg": background_image_url, # Pass URL
|
| 419 |
+
"title_arg": title_arg,
|
| 420 |
+
"settings_font_path_arg": settings_font_path_arg,
|
| 421 |
+
"settings_font_color_arg": settings_font_color_arg,
|
| 422 |
+
"seed_arg": seed_arg,
|
| 423 |
+
"overlap_arg": overlap_arg,
|
| 424 |
+
"prompt_index_arg": prompt_index_arg,
|
| 425 |
+
"include_title_arg": include_title_arg,
|
| 426 |
+
"include_settings_arg": include_settings_arg,
|
| 427 |
+
"harmony_only_arg": actual_harmony_only,
|
| 428 |
+
"profile_arg": profile_username_to_send,
|
| 429 |
+
"segment_length_arg": segment_length_arg,
|
| 430 |
+
"settings_font_size_arg": settings_font_size_arg,
|
| 431 |
+
"settings_animate_waveform_arg": settings_animate_waveform_arg,
|
| 432 |
+
"video_orientation_arg": video_orientation_arg,
|
| 433 |
+
"excerpt_duration_arg": excerpt_duration_arg,
|
| 434 |
+
}
|
| 435 |
|
| 436 |
+
print(f"Calling remote MCP tool 'UnlimitedMusicGen_predict' with arguments (URLs for files).")
|
| 437 |
+
results = predict_tool(**tool_args)
|
| 438 |
+
print(f"MCP tool call completed. Raw results: {results}")
|
|
|
|
| 439 |
|
| 440 |
+
if not isinstance(results, (list, tuple)) or len(results) != 3:
|
| 441 |
+
raise gr.Error(f"MCP tool 'UnlimitedMusicGen_predict' did not return the expected 3 values. Received: {results}")
|
| 442 |
+
|
| 443 |
+
waveform_video_path, wave_file_path, seed_used = results
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
+
if not ((waveform_video_path is None or isinstance(waveform_video_path, str)) and
|
| 446 |
+
(wave_file_path is None or isinstance(wave_file_path, str))):
|
| 447 |
+
error_msg = (f"MCP tool returned invalid file paths. "
|
| 448 |
+
f"Video path type: {type(waveform_video_path)}, "
|
| 449 |
+
f"Audio path type: {type(wave_file_path)}")
|
| 450 |
+
raise gr.Error(error_msg)
|
| 451 |
+
|
| 452 |
+
if not isinstance(seed_used, (int, float)): # Allow float for seed then cast later
|
| 453 |
+
raise gr.Error(f"MCP tool returned a non-numeric seed. Received type: {type(seed_used)}, value: {seed_used}")
|
| 454 |
+
|
| 455 |
+
return waveform_video_path, wave_file_path, int(seed_used)
|
| 456 |
+
|
| 457 |
+
except Exception as e:
|
| 458 |
+
error_message = f"Error during MCP tool call or file upload: {str(e)}"
|
| 459 |
+
print(error_message)
|
| 460 |
+
import traceback
|
| 461 |
+
traceback.print_exc()
|
| 462 |
+
if isinstance(e, gr.Error):
|
| 463 |
+
raise
|
| 464 |
+
else:
|
| 465 |
+
raise gr.Error(error_message)
|
| 466 |
+
finally:
|
| 467 |
+
# No file objects to close here as we are passing URLs
|
| 468 |
+
if mcp_client:
|
| 469 |
+
try:
|
| 470 |
+
mcp_client.disconnect()
|
| 471 |
+
print("MCP client disconnected.")
|
| 472 |
+
except Exception as e_disconnect:
|
| 473 |
+
print(f"Error disconnecting MCP client: {e_disconnect}")
|
| 474 |
|
| 475 |
gr.set_static_paths(paths=["fonts/","assets/","images/"])
|
| 476 |
def ui(**kwargs):
|
audiocraft/__init__.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
# flake8: noqa
|
| 8 |
-
from . import data, modules, models
|
| 9 |
-
|
| 10 |
-
__version__ = '1.3.Surn-MCP'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/__init__.py
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
# flake8: noqa
|
| 8 |
-
from . import audio, audio_dataset, info_audio_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/audio.py
DELETED
|
@@ -1,422 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Audio IO methods are defined in this module (info, read, write),
|
| 9 |
-
We rely on av library for faster read when possible, otherwise on torchaudio.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import logging
|
| 15 |
-
import typing as tp
|
| 16 |
-
|
| 17 |
-
import numpy as np
|
| 18 |
-
import soundfile
|
| 19 |
-
import torch
|
| 20 |
-
from torch.nn import functional as F
|
| 21 |
-
import torchaudio as ta
|
| 22 |
-
|
| 23 |
-
import av
|
| 24 |
-
import subprocess as sp
|
| 25 |
-
|
| 26 |
-
from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
_av_initialized = False
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def _init_av():
|
| 33 |
-
global _av_initialized
|
| 34 |
-
if _av_initialized:
|
| 35 |
-
return
|
| 36 |
-
logger = logging.getLogger('libav.mp3')
|
| 37 |
-
logger.setLevel(logging.ERROR)
|
| 38 |
-
_av_initialized = True
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
@dataclass(frozen=True)
|
| 42 |
-
class AudioFileInfo:
|
| 43 |
-
sample_rate: int
|
| 44 |
-
duration: float
|
| 45 |
-
channels: int
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
| 49 |
-
_init_av()
|
| 50 |
-
with av.open(str(filepath)) as af:
|
| 51 |
-
stream = af.streams.audio[0]
|
| 52 |
-
sample_rate = stream.codec_context.sample_rate
|
| 53 |
-
duration = float(stream.duration * stream.time_base)
|
| 54 |
-
channels = stream.channels
|
| 55 |
-
return AudioFileInfo(sample_rate, duration, channels)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
| 59 |
-
info = soundfile.info(filepath)
|
| 60 |
-
return AudioFileInfo(info.samplerate, info.duration, info.channels)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
|
| 64 |
-
# torchaudio no longer returns useful duration informations for some formats like mp3s.
|
| 65 |
-
filepath = Path(filepath)
|
| 66 |
-
if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
|
| 67 |
-
# ffmpeg has some weird issue with flac.
|
| 68 |
-
return _soundfile_info(filepath)
|
| 69 |
-
else:
|
| 70 |
-
return _av_info(filepath)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
|
| 74 |
-
"""FFMPEG-based audio file reading using PyAV bindings.
|
| 75 |
-
Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
filepath (str or Path): Path to audio file to read.
|
| 79 |
-
seek_time (float): Time at which to start reading in the file.
|
| 80 |
-
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
| 81 |
-
Returns:
|
| 82 |
-
tuple of torch.Tensor, int: Tuple containing audio data and sample rate
|
| 83 |
-
"""
|
| 84 |
-
_init_av()
|
| 85 |
-
with av.open(str(filepath)) as af:
|
| 86 |
-
stream = af.streams.audio[0]
|
| 87 |
-
sr = stream.codec_context.sample_rate
|
| 88 |
-
num_frames = int(sr * duration) if duration >= 0 else -1
|
| 89 |
-
frame_offset = int(sr * seek_time)
|
| 90 |
-
# we need a small negative offset otherwise we get some edge artifact
|
| 91 |
-
# from the mp3 decoder.
|
| 92 |
-
af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
|
| 93 |
-
frames = []
|
| 94 |
-
length = 0
|
| 95 |
-
for frame in af.decode(streams=stream.index):
|
| 96 |
-
current_offset = int(frame.rate * frame.pts * frame.time_base)
|
| 97 |
-
strip = max(0, frame_offset - current_offset)
|
| 98 |
-
buf = torch.from_numpy(frame.to_ndarray())
|
| 99 |
-
if buf.shape[0] != stream.channels:
|
| 100 |
-
buf = buf.view(-1, stream.channels).t()
|
| 101 |
-
buf = buf[:, strip:]
|
| 102 |
-
frames.append(buf)
|
| 103 |
-
length += buf.shape[1]
|
| 104 |
-
if num_frames > 0 and length >= num_frames:
|
| 105 |
-
break
|
| 106 |
-
assert frames
|
| 107 |
-
# If the above assert fails, it is likely because we seeked past the end of file point,
|
| 108 |
-
# in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
|
| 109 |
-
# This will need proper debugging, in due time.
|
| 110 |
-
wav = torch.cat(frames, dim=1)
|
| 111 |
-
assert wav.shape[0] == stream.channels
|
| 112 |
-
if num_frames > 0:
|
| 113 |
-
wav = wav[:, :num_frames]
|
| 114 |
-
return f32_pcm(wav), sr
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
| 118 |
-
duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
|
| 119 |
-
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
| 120 |
-
|
| 121 |
-
Args:
|
| 122 |
-
filepath (str or Path): Path to audio file to read.
|
| 123 |
-
seek_time (float): Time at which to start reading in the file.
|
| 124 |
-
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
| 125 |
-
pad (bool): Pad output audio if not reaching expected duration.
|
| 126 |
-
Returns:
|
| 127 |
-
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
| 128 |
-
"""
|
| 129 |
-
fp = Path(filepath)
|
| 130 |
-
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
| 131 |
-
# There is some bug with ffmpeg and reading flac
|
| 132 |
-
info = _soundfile_info(filepath)
|
| 133 |
-
frames = -1 if duration <= 0 else int(duration * info.sample_rate)
|
| 134 |
-
frame_offset = int(seek_time * info.sample_rate)
|
| 135 |
-
wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
|
| 136 |
-
assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
|
| 137 |
-
wav = torch.from_numpy(wav).t().contiguous()
|
| 138 |
-
if len(wav.shape) == 1:
|
| 139 |
-
wav = torch.unsqueeze(wav, 0)
|
| 140 |
-
elif (
|
| 141 |
-
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
| 142 |
-
and duration <= 0 and seek_time == 0
|
| 143 |
-
):
|
| 144 |
-
# Torchaudio is faster if we load an entire file at once.
|
| 145 |
-
wav, sr = ta.load(fp)
|
| 146 |
-
else:
|
| 147 |
-
wav, sr = _av_read(filepath, seek_time, duration)
|
| 148 |
-
if pad and duration > 0:
|
| 149 |
-
expected_frames = int(duration * sr)
|
| 150 |
-
wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
|
| 151 |
-
return wav, sr
|
| 152 |
-
|
| 153 |
-
def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
|
| 154 |
-
# ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
|
| 155 |
-
assert wav.dim() == 2, wav.shape
|
| 156 |
-
command = [
|
| 157 |
-
'ffmpeg',
|
| 158 |
-
'-loglevel', 'error',
|
| 159 |
-
'-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
|
| 160 |
-
'-i', '-'] + flags + [str(out_path)]
|
| 161 |
-
input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
|
| 162 |
-
sp.run(command, input=input_, check=True)
|
| 163 |
-
|
| 164 |
-
def audio_write(stem_name: tp.Union[str, Path],
|
| 165 |
-
wav: torch.Tensor, sample_rate: int,
|
| 166 |
-
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
|
| 167 |
-
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
| 168 |
-
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
| 169 |
-
loudness_compressor: bool = False,
|
| 170 |
-
log_clipping: bool = True, make_parent_dir: bool = True,
|
| 171 |
-
add_suffix: bool = True, channels:int = 1) -> Path:
|
| 172 |
-
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
| 173 |
-
|
| 174 |
-
Args:
|
| 175 |
-
stem_name (str or Path): Filename without extension which will be added automatically.
|
| 176 |
-
format (str): Either "wav" or "mp3".
|
| 177 |
-
mp3_rate (int): kbps when using mp3s.
|
| 178 |
-
normalize (bool): if `True` (default), normalizes according to the prescribed
|
| 179 |
-
strategy (see after). If `False`, the strategy is only used in case clipping
|
| 180 |
-
would happen.
|
| 181 |
-
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
| 182 |
-
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
| 183 |
-
with extra headroom to avoid clipping. 'clip' just clips.
|
| 184 |
-
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
| 185 |
-
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
| 186 |
-
than the `peak_clip` one to avoid further clipping.
|
| 187 |
-
loudness_headroom_db (float): Target loudness for loudness normalization.
|
| 188 |
-
loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
|
| 189 |
-
when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
|
| 190 |
-
occurs despite strategy (only for 'rms').
|
| 191 |
-
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
| 192 |
-
Returns:
|
| 193 |
-
Path: Path of the saved audio.
|
| 194 |
-
"""
|
| 195 |
-
assert wav.dtype.is_floating_point, "wav is not floating point"
|
| 196 |
-
if wav.dim() == 1:
|
| 197 |
-
wav = wav[None]
|
| 198 |
-
elif wav.dim() > 2:
|
| 199 |
-
raise ValueError("Input wav should be at most 2 dimension.")
|
| 200 |
-
assert wav.isfinite().all()
|
| 201 |
-
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
| 202 |
-
rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
|
| 203 |
-
sample_rate=sample_rate, stem_name=str(stem_name))
|
| 204 |
-
if channels > 1:
|
| 205 |
-
wav = convert_audio(wav,sample_rate, sample_rate, channels)
|
| 206 |
-
kwargs: dict = {}
|
| 207 |
-
if format == 'mp3':
|
| 208 |
-
suffix = '.mp3'
|
| 209 |
-
kwargs.update({"compression": mp3_rate})
|
| 210 |
-
elif format == 'wav':
|
| 211 |
-
wav = i16_pcm(wav)
|
| 212 |
-
suffix = '.wav'
|
| 213 |
-
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
|
| 214 |
-
else:
|
| 215 |
-
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
| 216 |
-
if not add_suffix:
|
| 217 |
-
suffix = ''
|
| 218 |
-
path = Path(str(stem_name) + suffix)
|
| 219 |
-
if make_parent_dir:
|
| 220 |
-
path.parent.mkdir(exist_ok=True, parents=True)
|
| 221 |
-
try:
|
| 222 |
-
ta.save(path, wav, sample_rate, **kwargs)
|
| 223 |
-
except Exception:
|
| 224 |
-
if path.exists():
|
| 225 |
-
# we do not want to leave half written files around.
|
| 226 |
-
path.unlink()
|
| 227 |
-
raise
|
| 228 |
-
return path
|
| 229 |
-
|
| 230 |
-
def audio_write2(stem_name: tp.Union[str, Path],
|
| 231 |
-
wav: torch.Tensor, sample_rate: int,
|
| 232 |
-
format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
|
| 233 |
-
normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
| 234 |
-
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
| 235 |
-
loudness_compressor: bool = False,
|
| 236 |
-
log_clipping: bool = True, make_parent_dir: bool = True,
|
| 237 |
-
add_suffix: bool = True) -> Path:
|
| 238 |
-
"""Convenience function for saving audio to disk. Returns the filename the audio was written to.
|
| 239 |
-
|
| 240 |
-
Args:
|
| 241 |
-
stem_name (str or Path): Filename without extension which will be added automatically.
|
| 242 |
-
wav (torch.Tensor): Audio data to save.
|
| 243 |
-
sample_rate (int): Sample rate of audio data.
|
| 244 |
-
format (str): Either "wav", "mp3", "ogg", or "flac".
|
| 245 |
-
mp3_rate (int): kbps when using mp3s.
|
| 246 |
-
ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
|
| 247 |
-
normalize (bool): if `True` (default), normalizes according to the prescribed
|
| 248 |
-
strategy (see after). If `False`, the strategy is only used in case clipping
|
| 249 |
-
would happen.
|
| 250 |
-
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
| 251 |
-
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
| 252 |
-
with extra headroom to avoid clipping. 'clip' just clips.
|
| 253 |
-
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
| 254 |
-
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
| 255 |
-
than the `peak_clip` one to avoid further clipping.
|
| 256 |
-
loudness_headroom_db (float): Target loudness for loudness normalization.
|
| 257 |
-
loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
|
| 258 |
-
when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
|
| 259 |
-
occurs despite strategy (only for 'rms').
|
| 260 |
-
make_parent_dir (bool): Make parent directory if it doesn't exist.
|
| 261 |
-
Returns:
|
| 262 |
-
Path: Path of the saved audio.
|
| 263 |
-
"""
|
| 264 |
-
assert wav.dtype.is_floating_point, "wav is not floating point"
|
| 265 |
-
if wav.dim() == 1:
|
| 266 |
-
wav = wav[None]
|
| 267 |
-
elif wav.dim() > 2:
|
| 268 |
-
raise ValueError("Input wav should be at most 2 dimension.")
|
| 269 |
-
assert wav.isfinite().all()
|
| 270 |
-
wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
|
| 271 |
-
rms_headroom_db, loudness_headroom_db, loudness_compressor,
|
| 272 |
-
log_clipping=log_clipping, sample_rate=sample_rate,
|
| 273 |
-
stem_name=str(stem_name))
|
| 274 |
-
if format == 'mp3':
|
| 275 |
-
suffix = '.mp3'
|
| 276 |
-
flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
|
| 277 |
-
elif format == 'wav':
|
| 278 |
-
suffix = '.wav'
|
| 279 |
-
flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
|
| 280 |
-
elif format == 'ogg':
|
| 281 |
-
suffix = '.ogg'
|
| 282 |
-
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
|
| 283 |
-
if ogg_rate is not None:
|
| 284 |
-
flags += ['-b:a', f'{ogg_rate}k']
|
| 285 |
-
elif format == 'flac':
|
| 286 |
-
suffix = '.flac'
|
| 287 |
-
flags = ['-f', 'flac']
|
| 288 |
-
else:
|
| 289 |
-
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
|
| 290 |
-
if not add_suffix:
|
| 291 |
-
suffix = ''
|
| 292 |
-
path = Path(str(stem_name) + suffix)
|
| 293 |
-
if make_parent_dir:
|
| 294 |
-
path.parent.mkdir(exist_ok=True, parents=True)
|
| 295 |
-
try:
|
| 296 |
-
_piping_to_ffmpeg(path, wav, sample_rate, flags)
|
| 297 |
-
except Exception:
|
| 298 |
-
if path.exists():
|
| 299 |
-
# we do not want to leave half written files around.
|
| 300 |
-
path.unlink()
|
| 301 |
-
raise
|
| 302 |
-
return path
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
|
| 306 |
-
"""Get the mel-spectrogram from the raw audio.
|
| 307 |
-
|
| 308 |
-
Args:
|
| 309 |
-
y (numpy array): raw input
|
| 310 |
-
sr (int): Sampling rate
|
| 311 |
-
n_fft (int): Number of samples per FFT. Default is 2048.
|
| 312 |
-
hop_length (int): Number of samples between successive frames. Default is 512.
|
| 313 |
-
dur (float): Maxium duration to get the spectrograms
|
| 314 |
-
Returns:
|
| 315 |
-
spectro histogram as a numpy array
|
| 316 |
-
"""
|
| 317 |
-
import librosa
|
| 318 |
-
import librosa.display
|
| 319 |
-
|
| 320 |
-
spectrogram = librosa.feature.melspectrogram(
|
| 321 |
-
y=y, sr=sr, n_fft=n_fft, hop_length=hop_length
|
| 322 |
-
)
|
| 323 |
-
spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
|
| 324 |
-
return spectrogram_db
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
def save_spectrograms(
|
| 328 |
-
ys: tp.List[np.ndarray],
|
| 329 |
-
sr: int,
|
| 330 |
-
path: str,
|
| 331 |
-
names: tp.List[str],
|
| 332 |
-
n_fft: int = 4096,
|
| 333 |
-
hop_length: int = 128,
|
| 334 |
-
dur: float = 8.0,
|
| 335 |
-
):
|
| 336 |
-
"""Plot a spectrogram for an audio file.
|
| 337 |
-
|
| 338 |
-
Args:
|
| 339 |
-
ys: List of audio spectrograms
|
| 340 |
-
sr (int): Sampling rate of the audio file. Default is 22050 Hz.
|
| 341 |
-
path (str): Path to the plot file.
|
| 342 |
-
names: name of each spectrogram plot
|
| 343 |
-
n_fft (int): Number of samples per FFT. Default is 2048.
|
| 344 |
-
hop_length (int): Number of samples between successive frames. Default is 512.
|
| 345 |
-
dur (float): Maxium duration to plot the spectrograms
|
| 346 |
-
|
| 347 |
-
Returns:
|
| 348 |
-
None (plots the spectrogram using matplotlib)
|
| 349 |
-
"""
|
| 350 |
-
import matplotlib as mpl # type: ignore
|
| 351 |
-
import matplotlib.pyplot as plt # type: ignore
|
| 352 |
-
import librosa.display
|
| 353 |
-
|
| 354 |
-
if not names:
|
| 355 |
-
names = ["Ground Truth", "Audio Watermarked", "Watermark"]
|
| 356 |
-
ys = [wav[: int(dur * sr)] for wav in ys] # crop
|
| 357 |
-
assert len(names) == len(
|
| 358 |
-
ys
|
| 359 |
-
), f"There are {len(ys)} wavs but {len(names)} names ({names})"
|
| 360 |
-
|
| 361 |
-
# Set matplotlib stuff
|
| 362 |
-
BIGGER_SIZE = 10
|
| 363 |
-
SMALLER_SIZE = 8
|
| 364 |
-
linewidth = 234.8775 # linewidth in pt
|
| 365 |
-
|
| 366 |
-
plt.rc("font", size=BIGGER_SIZE, family="serif") # controls default text sizes
|
| 367 |
-
plt.rcParams["font.family"] = "DeJavu Serif"
|
| 368 |
-
plt.rcParams["font.serif"] = ["Times New Roman"]
|
| 369 |
-
|
| 370 |
-
plt.rc("axes", titlesize=BIGGER_SIZE) # fontsize of the axes title
|
| 371 |
-
plt.rc("axes", labelsize=BIGGER_SIZE) # fontsize of the x and y labels
|
| 372 |
-
plt.rc("xtick", labelsize=BIGGER_SIZE) # fontsize of the tick labels
|
| 373 |
-
plt.rc("ytick", labelsize=SMALLER_SIZE) # fontsize of the tick labels
|
| 374 |
-
plt.rc("legend", fontsize=BIGGER_SIZE) # legend fontsize
|
| 375 |
-
plt.rc("figure", titlesize=BIGGER_SIZE)
|
| 376 |
-
height = 1.6 * linewidth / 72.0
|
| 377 |
-
fig, ax = plt.subplots(
|
| 378 |
-
nrows=len(ys),
|
| 379 |
-
ncols=1,
|
| 380 |
-
sharex=True,
|
| 381 |
-
figsize=(linewidth / 72.0, height),
|
| 382 |
-
)
|
| 383 |
-
fig.tight_layout()
|
| 384 |
-
|
| 385 |
-
# Plot the spectrogram
|
| 386 |
-
|
| 387 |
-
for i, ysi in enumerate(ys):
|
| 388 |
-
spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length)
|
| 389 |
-
if i == 0:
|
| 390 |
-
cax = fig.add_axes(
|
| 391 |
-
[
|
| 392 |
-
ax[0].get_position().x1 + 0.01, # type: ignore
|
| 393 |
-
ax[-1].get_position().y0,
|
| 394 |
-
0.02,
|
| 395 |
-
ax[0].get_position().y1 - ax[-1].get_position().y0,
|
| 396 |
-
]
|
| 397 |
-
)
|
| 398 |
-
fig.colorbar(
|
| 399 |
-
mpl.cm.ScalarMappable(
|
| 400 |
-
norm=mpl.colors.Normalize(
|
| 401 |
-
np.min(spectrogram_db), np.max(spectrogram_db)
|
| 402 |
-
),
|
| 403 |
-
cmap="magma",
|
| 404 |
-
),
|
| 405 |
-
ax=ax,
|
| 406 |
-
orientation="vertical",
|
| 407 |
-
format="%+2.0f dB",
|
| 408 |
-
cax=cax,
|
| 409 |
-
)
|
| 410 |
-
librosa.display.specshow(
|
| 411 |
-
spectrogram_db,
|
| 412 |
-
sr=sr,
|
| 413 |
-
hop_length=hop_length,
|
| 414 |
-
x_axis="time",
|
| 415 |
-
y_axis="mel",
|
| 416 |
-
ax=ax[i],
|
| 417 |
-
)
|
| 418 |
-
ax[i].set(title=names[i])
|
| 419 |
-
ax[i].yaxis.set_label_text(None)
|
| 420 |
-
ax[i].label_outer()
|
| 421 |
-
fig.savefig(path, bbox_inches="tight")
|
| 422 |
-
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/audio_dataset.py
DELETED
|
@@ -1,587 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
"""AudioDataset support. In order to handle a larger number of files
|
| 7 |
-
without having to scan again the folders, we precompute some metadata
|
| 8 |
-
(filename, sample rate, duration), and use that to efficiently sample audio segments.
|
| 9 |
-
"""
|
| 10 |
-
import argparse
|
| 11 |
-
import copy
|
| 12 |
-
from concurrent.futures import ThreadPoolExecutor, Future
|
| 13 |
-
from dataclasses import dataclass, fields
|
| 14 |
-
from contextlib import ExitStack
|
| 15 |
-
from functools import lru_cache
|
| 16 |
-
import gzip
|
| 17 |
-
import json
|
| 18 |
-
import logging
|
| 19 |
-
import os
|
| 20 |
-
from pathlib import Path
|
| 21 |
-
import random
|
| 22 |
-
import sys
|
| 23 |
-
import typing as tp
|
| 24 |
-
|
| 25 |
-
import torch
|
| 26 |
-
import torch.nn.functional as F
|
| 27 |
-
|
| 28 |
-
from .audio import audio_read, audio_info
|
| 29 |
-
from .audio_utils import convert_audio
|
| 30 |
-
from .zip import PathInZip
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
import dora
|
| 34 |
-
except ImportError:
|
| 35 |
-
dora = None # type: ignore
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
@dataclass(order=True)
|
| 39 |
-
class BaseInfo:
|
| 40 |
-
|
| 41 |
-
@classmethod
|
| 42 |
-
def _dict2fields(cls, dictionary: dict):
|
| 43 |
-
return {
|
| 44 |
-
field.name: dictionary[field.name]
|
| 45 |
-
for field in fields(cls) if field.name in dictionary
|
| 46 |
-
}
|
| 47 |
-
|
| 48 |
-
@classmethod
|
| 49 |
-
def from_dict(cls, dictionary: dict):
|
| 50 |
-
_dictionary = cls._dict2fields(dictionary)
|
| 51 |
-
return cls(**_dictionary)
|
| 52 |
-
|
| 53 |
-
def to_dict(self):
|
| 54 |
-
return {
|
| 55 |
-
field.name: self.__getattribute__(field.name)
|
| 56 |
-
for field in fields(self)
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
@dataclass(order=True)
|
| 61 |
-
class AudioMeta(BaseInfo):
|
| 62 |
-
path: str
|
| 63 |
-
duration: float
|
| 64 |
-
sample_rate: int
|
| 65 |
-
amplitude: tp.Optional[float] = None
|
| 66 |
-
weight: tp.Optional[float] = None
|
| 67 |
-
# info_path is used to load additional information about the audio file that is stored in zip files.
|
| 68 |
-
info_path: tp.Optional[PathInZip] = None
|
| 69 |
-
|
| 70 |
-
@classmethod
|
| 71 |
-
def from_dict(cls, dictionary: dict):
|
| 72 |
-
base = cls._dict2fields(dictionary)
|
| 73 |
-
if 'info_path' in base and base['info_path'] is not None:
|
| 74 |
-
base['info_path'] = PathInZip(base['info_path'])
|
| 75 |
-
return cls(**base)
|
| 76 |
-
|
| 77 |
-
def to_dict(self):
|
| 78 |
-
d = super().to_dict()
|
| 79 |
-
if d['info_path'] is not None:
|
| 80 |
-
d['info_path'] = str(d['info_path'])
|
| 81 |
-
return d
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
@dataclass(order=True)
|
| 85 |
-
class SegmentInfo(BaseInfo):
|
| 86 |
-
meta: AudioMeta
|
| 87 |
-
seek_time: float
|
| 88 |
-
# The following values are given once the audio is processed, e.g.
|
| 89 |
-
# at the target sample rate and target number of channels.
|
| 90 |
-
n_frames: int # actual number of frames without padding
|
| 91 |
-
total_frames: int # total number of frames, padding included
|
| 92 |
-
sample_rate: int # actual sample rate
|
| 93 |
-
channels: int # number of audio channels.
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
|
| 97 |
-
|
| 98 |
-
logger = logging.getLogger(__name__)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
|
| 102 |
-
"""AudioMeta from a path to an audio file.
|
| 103 |
-
|
| 104 |
-
Args:
|
| 105 |
-
file_path (str): Resolved path of valid audio file.
|
| 106 |
-
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
| 107 |
-
Returns:
|
| 108 |
-
AudioMeta: Audio file path and its metadata.
|
| 109 |
-
"""
|
| 110 |
-
info = audio_info(file_path)
|
| 111 |
-
amplitude: tp.Optional[float] = None
|
| 112 |
-
if not minimal:
|
| 113 |
-
wav, sr = audio_read(file_path)
|
| 114 |
-
amplitude = wav.abs().max().item()
|
| 115 |
-
return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
|
| 119 |
-
"""If Dora is available as a dependency, try to resolve potential relative paths
|
| 120 |
-
in list of AudioMeta. This method is expected to be used when loading meta from file.
|
| 121 |
-
|
| 122 |
-
Args:
|
| 123 |
-
m (AudioMeta): Audio meta to resolve.
|
| 124 |
-
fast (bool): If True, uses a really fast check for determining if a file
|
| 125 |
-
is already absolute or not. Only valid on Linux/Mac.
|
| 126 |
-
Returns:
|
| 127 |
-
AudioMeta: Audio meta with resolved path.
|
| 128 |
-
"""
|
| 129 |
-
def is_abs(m):
|
| 130 |
-
if fast:
|
| 131 |
-
return str(m)[0] == '/'
|
| 132 |
-
else:
|
| 133 |
-
os.path.isabs(str(m))
|
| 134 |
-
|
| 135 |
-
if not dora:
|
| 136 |
-
return m
|
| 137 |
-
|
| 138 |
-
if not is_abs(m.path):
|
| 139 |
-
m.path = dora.git_save.to_absolute_path(m.path)
|
| 140 |
-
if m.info_path is not None and not is_abs(m.info_path.zip_path):
|
| 141 |
-
m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
|
| 142 |
-
return m
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def find_audio_files(path: tp.Union[Path, str],
|
| 146 |
-
exts: tp.List[str] = DEFAULT_EXTS,
|
| 147 |
-
resolve: bool = True,
|
| 148 |
-
minimal: bool = True,
|
| 149 |
-
progress: bool = False,
|
| 150 |
-
workers: int = 0) -> tp.List[AudioMeta]:
|
| 151 |
-
"""Build a list of AudioMeta from a given path,
|
| 152 |
-
collecting relevant audio files and fetching meta info.
|
| 153 |
-
|
| 154 |
-
Args:
|
| 155 |
-
path (str or Path): Path to folder containing audio files.
|
| 156 |
-
exts (list of str): List of file extensions to consider for audio files.
|
| 157 |
-
minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
|
| 158 |
-
progress (bool): Whether to log progress on audio files collection.
|
| 159 |
-
workers (int): number of parallel workers, if 0, use only the current thread.
|
| 160 |
-
Returns:
|
| 161 |
-
list of AudioMeta: List of audio file path and its metadata.
|
| 162 |
-
"""
|
| 163 |
-
audio_files = []
|
| 164 |
-
futures: tp.List[Future] = []
|
| 165 |
-
pool: tp.Optional[ThreadPoolExecutor] = None
|
| 166 |
-
with ExitStack() as stack:
|
| 167 |
-
if workers > 0:
|
| 168 |
-
pool = ThreadPoolExecutor(workers)
|
| 169 |
-
stack.enter_context(pool)
|
| 170 |
-
|
| 171 |
-
if progress:
|
| 172 |
-
print("Finding audio files...")
|
| 173 |
-
for root, folders, files in os.walk(path, followlinks=True):
|
| 174 |
-
for file in files:
|
| 175 |
-
full_path = Path(root) / file
|
| 176 |
-
if full_path.suffix.lower() in exts:
|
| 177 |
-
audio_files.append(full_path)
|
| 178 |
-
if pool is not None:
|
| 179 |
-
futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
|
| 180 |
-
if progress:
|
| 181 |
-
print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
|
| 182 |
-
|
| 183 |
-
if progress:
|
| 184 |
-
print("Getting audio metadata...")
|
| 185 |
-
meta: tp.List[AudioMeta] = []
|
| 186 |
-
for idx, file_path in enumerate(audio_files):
|
| 187 |
-
try:
|
| 188 |
-
if pool is None:
|
| 189 |
-
m = _get_audio_meta(str(file_path), minimal)
|
| 190 |
-
else:
|
| 191 |
-
m = futures[idx].result()
|
| 192 |
-
if resolve:
|
| 193 |
-
m = _resolve_audio_meta(m)
|
| 194 |
-
except Exception as err:
|
| 195 |
-
print("Error with", str(file_path), err, file=sys.stderr)
|
| 196 |
-
continue
|
| 197 |
-
meta.append(m)
|
| 198 |
-
if progress:
|
| 199 |
-
print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
|
| 200 |
-
meta.sort()
|
| 201 |
-
return meta
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def load_audio_meta(path: tp.Union[str, Path],
|
| 205 |
-
resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
|
| 206 |
-
"""Load list of AudioMeta from an optionally compressed json file.
|
| 207 |
-
|
| 208 |
-
Args:
|
| 209 |
-
path (str or Path): Path to JSON file.
|
| 210 |
-
resolve (bool): Whether to resolve the path from AudioMeta (default=True).
|
| 211 |
-
fast (bool): activates some tricks to make things faster.
|
| 212 |
-
Returns:
|
| 213 |
-
list of AudioMeta: List of audio file path and its total duration.
|
| 214 |
-
"""
|
| 215 |
-
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
| 216 |
-
with open_fn(path, 'rb') as fp: # type: ignore
|
| 217 |
-
lines = fp.readlines()
|
| 218 |
-
meta = []
|
| 219 |
-
for line in lines:
|
| 220 |
-
d = json.loads(line)
|
| 221 |
-
m = AudioMeta.from_dict(d)
|
| 222 |
-
if resolve:
|
| 223 |
-
m = _resolve_audio_meta(m, fast=fast)
|
| 224 |
-
meta.append(m)
|
| 225 |
-
return meta
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
|
| 229 |
-
"""Save the audio metadata to the file pointer as json.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
path (str or Path): Path to JSON file.
|
| 233 |
-
metadata (list of BaseAudioMeta): List of audio meta to save.
|
| 234 |
-
"""
|
| 235 |
-
Path(path).parent.mkdir(exist_ok=True, parents=True)
|
| 236 |
-
open_fn = gzip.open if str(path).lower().endswith('.gz') else open
|
| 237 |
-
with open_fn(path, 'wb') as fp: # type: ignore
|
| 238 |
-
for m in meta:
|
| 239 |
-
json_str = json.dumps(m.to_dict()) + '\n'
|
| 240 |
-
json_bytes = json_str.encode('utf-8')
|
| 241 |
-
fp.write(json_bytes)
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
class AudioDataset:
|
| 245 |
-
"""Base audio dataset.
|
| 246 |
-
|
| 247 |
-
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
|
| 248 |
-
and potentially additional information, by creating random segments from the list of audio
|
| 249 |
-
files referenced in the metadata and applying minimal data pre-processing such as resampling,
|
| 250 |
-
mixing of channels, padding, etc.
|
| 251 |
-
|
| 252 |
-
If no segment_duration value is provided, the AudioDataset will return the full wav for each
|
| 253 |
-
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
|
| 254 |
-
duration, applying padding if required.
|
| 255 |
-
|
| 256 |
-
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
|
| 257 |
-
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
|
| 258 |
-
original audio meta.
|
| 259 |
-
|
| 260 |
-
Note that you can call `start_epoch(epoch)` in order to get
|
| 261 |
-
a deterministic "randomization" for `shuffle=True`.
|
| 262 |
-
For a given epoch and dataset index, this will always return the same extract.
|
| 263 |
-
You can get back some diversity by setting the `shuffle_seed` param.
|
| 264 |
-
|
| 265 |
-
Args:
|
| 266 |
-
meta (list of AudioMeta): List of audio files metadata.
|
| 267 |
-
segment_duration (float, optional): Optional segment duration of audio to load.
|
| 268 |
-
If not specified, the dataset will load the full audio segment from the file.
|
| 269 |
-
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
|
| 270 |
-
sample_rate (int): Target sample rate of the loaded audio samples.
|
| 271 |
-
channels (int): Target number of channels of the loaded audio samples.
|
| 272 |
-
sample_on_duration (bool): Set to `True` to sample segments with probability
|
| 273 |
-
dependent on audio file duration. This is only used if `segment_duration` is provided.
|
| 274 |
-
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
|
| 275 |
-
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
|
| 276 |
-
of the file duration and file weight. This is only used if `segment_duration` is provided.
|
| 277 |
-
min_segment_ratio (float): Minimum segment ratio to use when the audio file
|
| 278 |
-
is shorter than the desired segment.
|
| 279 |
-
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
|
| 280 |
-
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
|
| 281 |
-
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
|
| 282 |
-
audio shorter than this will be filtered out.
|
| 283 |
-
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
|
| 284 |
-
audio longer than this will be filtered out.
|
| 285 |
-
shuffle_seed (int): can be used to further randomize
|
| 286 |
-
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
|
| 287 |
-
with the expected segment_duration (which must be provided if load_wav is False).
|
| 288 |
-
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
|
| 289 |
-
are False. Will ensure a permutation on files when going through the dataset.
|
| 290 |
-
In that case the epoch number must be provided in order for the model
|
| 291 |
-
to continue the permutation across epochs. In that case, it is assumed
|
| 292 |
-
that `num_samples = total_batch_size * num_updates_per_epoch`, with
|
| 293 |
-
`total_batch_size` the overall batch size accounting for all gpus.
|
| 294 |
-
"""
|
| 295 |
-
def __init__(self,
|
| 296 |
-
meta: tp.List[AudioMeta],
|
| 297 |
-
segment_duration: tp.Optional[float] = None,
|
| 298 |
-
shuffle: bool = True,
|
| 299 |
-
num_samples: int = 10_000,
|
| 300 |
-
sample_rate: int = 48_000,
|
| 301 |
-
channels: int = 2,
|
| 302 |
-
pad: bool = True,
|
| 303 |
-
sample_on_duration: bool = True,
|
| 304 |
-
sample_on_weight: bool = True,
|
| 305 |
-
min_segment_ratio: float = 0.5,
|
| 306 |
-
max_read_retry: int = 10,
|
| 307 |
-
return_info: bool = False,
|
| 308 |
-
min_audio_duration: tp.Optional[float] = None,
|
| 309 |
-
max_audio_duration: tp.Optional[float] = None,
|
| 310 |
-
shuffle_seed: int = 0,
|
| 311 |
-
load_wav: bool = True,
|
| 312 |
-
permutation_on_files: bool = False,
|
| 313 |
-
):
|
| 314 |
-
assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
|
| 315 |
-
assert segment_duration is None or segment_duration > 0
|
| 316 |
-
assert segment_duration is None or min_segment_ratio >= 0
|
| 317 |
-
self.segment_duration = segment_duration
|
| 318 |
-
self.min_segment_ratio = min_segment_ratio
|
| 319 |
-
self.max_audio_duration = max_audio_duration
|
| 320 |
-
self.min_audio_duration = min_audio_duration
|
| 321 |
-
if self.min_audio_duration is not None and self.max_audio_duration is not None:
|
| 322 |
-
assert self.min_audio_duration <= self.max_audio_duration
|
| 323 |
-
self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
|
| 324 |
-
assert len(self.meta) # Fail fast if all data has been filtered.
|
| 325 |
-
self.total_duration = sum(d.duration for d in self.meta)
|
| 326 |
-
|
| 327 |
-
if segment_duration is None:
|
| 328 |
-
num_samples = len(self.meta)
|
| 329 |
-
self.num_samples = num_samples
|
| 330 |
-
self.shuffle = shuffle
|
| 331 |
-
self.sample_rate = sample_rate
|
| 332 |
-
self.channels = channels
|
| 333 |
-
self.pad = pad
|
| 334 |
-
self.sample_on_weight = sample_on_weight
|
| 335 |
-
self.sample_on_duration = sample_on_duration
|
| 336 |
-
self.sampling_probabilities = self._get_sampling_probabilities()
|
| 337 |
-
self.max_read_retry = max_read_retry
|
| 338 |
-
self.return_info = return_info
|
| 339 |
-
self.shuffle_seed = shuffle_seed
|
| 340 |
-
self.current_epoch: tp.Optional[int] = None
|
| 341 |
-
self.load_wav = load_wav
|
| 342 |
-
if not load_wav:
|
| 343 |
-
assert segment_duration is not None
|
| 344 |
-
self.permutation_on_files = permutation_on_files
|
| 345 |
-
if permutation_on_files:
|
| 346 |
-
assert not self.sample_on_duration
|
| 347 |
-
assert not self.sample_on_weight
|
| 348 |
-
assert self.shuffle
|
| 349 |
-
|
| 350 |
-
def start_epoch(self, epoch: int):
|
| 351 |
-
self.current_epoch = epoch
|
| 352 |
-
|
| 353 |
-
def __len__(self):
|
| 354 |
-
return self.num_samples
|
| 355 |
-
|
| 356 |
-
def _get_sampling_probabilities(self, normalized: bool = True):
|
| 357 |
-
"""Return the sampling probabilities for each file inside `self.meta`."""
|
| 358 |
-
scores: tp.List[float] = []
|
| 359 |
-
for file_meta in self.meta:
|
| 360 |
-
score = 1.
|
| 361 |
-
if self.sample_on_weight and file_meta.weight is not None:
|
| 362 |
-
score *= file_meta.weight
|
| 363 |
-
if self.sample_on_duration:
|
| 364 |
-
score *= file_meta.duration
|
| 365 |
-
scores.append(score)
|
| 366 |
-
probabilities = torch.tensor(scores)
|
| 367 |
-
if normalized:
|
| 368 |
-
probabilities /= probabilities.sum()
|
| 369 |
-
return probabilities
|
| 370 |
-
|
| 371 |
-
@staticmethod
|
| 372 |
-
@lru_cache(16)
|
| 373 |
-
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
|
| 374 |
-
# Used to keep the most recent files permutation in memory implicitely.
|
| 375 |
-
# will work unless someone is using a lot of Datasets in parallel.
|
| 376 |
-
rng = torch.Generator()
|
| 377 |
-
rng.manual_seed(base_seed + permutation_index)
|
| 378 |
-
return torch.randperm(num_files, generator=rng)
|
| 379 |
-
|
| 380 |
-
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
|
| 381 |
-
"""Sample a given file from `self.meta`. Can be overridden in subclasses.
|
| 382 |
-
This is only called if `segment_duration` is not None.
|
| 383 |
-
|
| 384 |
-
You must use the provided random number generator `rng` for reproducibility.
|
| 385 |
-
You can further make use of the index accessed.
|
| 386 |
-
"""
|
| 387 |
-
if self.permutation_on_files:
|
| 388 |
-
assert self.current_epoch is not None
|
| 389 |
-
total_index = self.current_epoch * len(self) + index
|
| 390 |
-
permutation_index = total_index // len(self.meta)
|
| 391 |
-
relative_index = total_index % len(self.meta)
|
| 392 |
-
permutation = AudioDataset._get_file_permutation(
|
| 393 |
-
len(self.meta), permutation_index, self.shuffle_seed)
|
| 394 |
-
file_index = permutation[relative_index]
|
| 395 |
-
return self.meta[file_index]
|
| 396 |
-
|
| 397 |
-
if not self.sample_on_weight and not self.sample_on_duration:
|
| 398 |
-
file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
|
| 399 |
-
else:
|
| 400 |
-
file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
|
| 401 |
-
|
| 402 |
-
return self.meta[file_index]
|
| 403 |
-
|
| 404 |
-
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
|
| 405 |
-
# Override this method in subclass if needed.
|
| 406 |
-
if self.load_wav:
|
| 407 |
-
return audio_read(path, seek_time, duration, pad=False)
|
| 408 |
-
else:
|
| 409 |
-
assert self.segment_duration is not None
|
| 410 |
-
n_frames = int(self.sample_rate * self.segment_duration)
|
| 411 |
-
return torch.zeros(self.channels, n_frames), self.sample_rate
|
| 412 |
-
|
| 413 |
-
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
|
| 414 |
-
if self.segment_duration is None:
|
| 415 |
-
file_meta = self.meta[index]
|
| 416 |
-
out, sr = audio_read(file_meta.path)
|
| 417 |
-
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
| 418 |
-
n_frames = out.shape[-1]
|
| 419 |
-
segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
|
| 420 |
-
sample_rate=self.sample_rate, channels=out.shape[0])
|
| 421 |
-
else:
|
| 422 |
-
rng = torch.Generator()
|
| 423 |
-
if self.shuffle:
|
| 424 |
-
# We use index, plus extra randomness, either totally random if we don't know the epoch.
|
| 425 |
-
# otherwise we make use of the epoch number and optional shuffle_seed.
|
| 426 |
-
if self.current_epoch is None:
|
| 427 |
-
rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
|
| 428 |
-
else:
|
| 429 |
-
rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
|
| 430 |
-
else:
|
| 431 |
-
# We only use index
|
| 432 |
-
rng.manual_seed(index)
|
| 433 |
-
|
| 434 |
-
for retry in range(self.max_read_retry):
|
| 435 |
-
file_meta = self.sample_file(index, rng)
|
| 436 |
-
# We add some variance in the file position even if audio file is smaller than segment
|
| 437 |
-
# without ending up with empty segments
|
| 438 |
-
max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
|
| 439 |
-
seek_time = torch.rand(1, generator=rng).item() * max_seek
|
| 440 |
-
try:
|
| 441 |
-
out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
|
| 442 |
-
out = convert_audio(out, sr, self.sample_rate, self.channels)
|
| 443 |
-
n_frames = out.shape[-1]
|
| 444 |
-
target_frames = int(self.segment_duration * self.sample_rate)
|
| 445 |
-
if self.pad:
|
| 446 |
-
out = F.pad(out, (0, target_frames - n_frames))
|
| 447 |
-
segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
|
| 448 |
-
sample_rate=self.sample_rate, channels=out.shape[0])
|
| 449 |
-
except Exception as exc:
|
| 450 |
-
logger.warning("Error opening file %s: %r", file_meta.path, exc)
|
| 451 |
-
if retry == self.max_read_retry - 1:
|
| 452 |
-
raise
|
| 453 |
-
else:
|
| 454 |
-
break
|
| 455 |
-
|
| 456 |
-
if self.return_info:
|
| 457 |
-
# Returns the wav and additional information on the wave segment
|
| 458 |
-
return out, segment_info
|
| 459 |
-
else:
|
| 460 |
-
return out
|
| 461 |
-
|
| 462 |
-
def collater(self, samples):
|
| 463 |
-
"""The collater function has to be provided to the dataloader
|
| 464 |
-
if AudioDataset has return_info=True in order to properly collate
|
| 465 |
-
the samples of a batch.
|
| 466 |
-
"""
|
| 467 |
-
if self.segment_duration is None and len(samples) > 1:
|
| 468 |
-
assert self.pad, "Must allow padding when batching examples of different durations."
|
| 469 |
-
|
| 470 |
-
# In this case the audio reaching the collater is of variable length as segment_duration=None.
|
| 471 |
-
to_pad = self.segment_duration is None and self.pad
|
| 472 |
-
if to_pad:
|
| 473 |
-
max_len = max([wav.shape[-1] for wav, _ in samples])
|
| 474 |
-
|
| 475 |
-
def _pad_wav(wav):
|
| 476 |
-
return F.pad(wav, (0, max_len - wav.shape[-1]))
|
| 477 |
-
|
| 478 |
-
if self.return_info:
|
| 479 |
-
if len(samples) > 0:
|
| 480 |
-
assert len(samples[0]) == 2
|
| 481 |
-
assert isinstance(samples[0][0], torch.Tensor)
|
| 482 |
-
assert isinstance(samples[0][1], SegmentInfo)
|
| 483 |
-
|
| 484 |
-
wavs = [wav for wav, _ in samples]
|
| 485 |
-
segment_infos = [copy.deepcopy(info) for _, info in samples]
|
| 486 |
-
|
| 487 |
-
if to_pad:
|
| 488 |
-
# Each wav could be of a different duration as they are not segmented.
|
| 489 |
-
for i in range(len(samples)):
|
| 490 |
-
# Determines the total length of the signal with padding, so we update here as we pad.
|
| 491 |
-
segment_infos[i].total_frames = max_len
|
| 492 |
-
wavs[i] = _pad_wav(wavs[i])
|
| 493 |
-
|
| 494 |
-
wav = torch.stack(wavs)
|
| 495 |
-
return wav, segment_infos
|
| 496 |
-
else:
|
| 497 |
-
assert isinstance(samples[0], torch.Tensor)
|
| 498 |
-
if to_pad:
|
| 499 |
-
samples = [_pad_wav(s) for s in samples]
|
| 500 |
-
return torch.stack(samples)
|
| 501 |
-
|
| 502 |
-
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
| 503 |
-
"""Filters out audio files with audio durations that will not allow to sample examples from them."""
|
| 504 |
-
orig_len = len(meta)
|
| 505 |
-
|
| 506 |
-
# Filter data that is too short.
|
| 507 |
-
if self.min_audio_duration is not None:
|
| 508 |
-
meta = [m for m in meta if m.duration >= self.min_audio_duration]
|
| 509 |
-
|
| 510 |
-
# Filter data that is too long.
|
| 511 |
-
if self.max_audio_duration is not None:
|
| 512 |
-
meta = [m for m in meta if m.duration <= self.max_audio_duration]
|
| 513 |
-
|
| 514 |
-
filtered_len = len(meta)
|
| 515 |
-
removed_percentage = 100*(1-float(filtered_len)/orig_len)
|
| 516 |
-
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
|
| 517 |
-
if removed_percentage < 10:
|
| 518 |
-
logging.debug(msg)
|
| 519 |
-
else:
|
| 520 |
-
logging.warning(msg)
|
| 521 |
-
return meta
|
| 522 |
-
|
| 523 |
-
@classmethod
|
| 524 |
-
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
|
| 525 |
-
"""Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
|
| 526 |
-
|
| 527 |
-
Args:
|
| 528 |
-
root (str or Path): Path to root folder containing audio files.
|
| 529 |
-
kwargs: Additional keyword arguments for the AudioDataset.
|
| 530 |
-
"""
|
| 531 |
-
root = Path(root)
|
| 532 |
-
if root.is_dir():
|
| 533 |
-
if (root / 'data.jsonl').exists():
|
| 534 |
-
root = root / 'data.jsonl'
|
| 535 |
-
elif (root / 'data.jsonl.gz').exists():
|
| 536 |
-
root = root / 'data.jsonl.gz'
|
| 537 |
-
else:
|
| 538 |
-
raise ValueError("Don't know where to read metadata from in the dir. "
|
| 539 |
-
"Expecting either a data.jsonl or data.jsonl.gz file but none found.")
|
| 540 |
-
meta = load_audio_meta(root)
|
| 541 |
-
return cls(meta, **kwargs)
|
| 542 |
-
|
| 543 |
-
@classmethod
|
| 544 |
-
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
|
| 545 |
-
exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
|
| 546 |
-
"""Instantiate AudioDataset from a path containing (possibly nested) audio files.
|
| 547 |
-
|
| 548 |
-
Args:
|
| 549 |
-
root (str or Path): Path to root folder containing audio files.
|
| 550 |
-
minimal_meta (bool): Whether to only load minimal metadata or not.
|
| 551 |
-
exts (list of str): Extensions for audio files.
|
| 552 |
-
kwargs: Additional keyword arguments for the AudioDataset.
|
| 553 |
-
"""
|
| 554 |
-
root = Path(root)
|
| 555 |
-
if root.is_file():
|
| 556 |
-
meta = load_audio_meta(root, resolve=True)
|
| 557 |
-
else:
|
| 558 |
-
meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
|
| 559 |
-
return cls(meta, **kwargs)
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
def main():
|
| 563 |
-
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
|
| 564 |
-
parser = argparse.ArgumentParser(
|
| 565 |
-
prog='audio_dataset',
|
| 566 |
-
description='Generate .jsonl files by scanning a folder.')
|
| 567 |
-
parser.add_argument('root', help='Root folder with all the audio files')
|
| 568 |
-
parser.add_argument('output_meta_file',
|
| 569 |
-
help='Output file to store the metadata, ')
|
| 570 |
-
parser.add_argument('--complete',
|
| 571 |
-
action='store_false', dest='minimal', default=True,
|
| 572 |
-
help='Retrieve all metadata, even the one that are expansive '
|
| 573 |
-
'to compute (e.g. normalization).')
|
| 574 |
-
parser.add_argument('--resolve',
|
| 575 |
-
action='store_true', default=False,
|
| 576 |
-
help='Resolve the paths to be absolute and with no symlinks.')
|
| 577 |
-
parser.add_argument('--workers',
|
| 578 |
-
default=10, type=int,
|
| 579 |
-
help='Number of workers.')
|
| 580 |
-
args = parser.parse_args()
|
| 581 |
-
meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
|
| 582 |
-
resolve=args.resolve, minimal=args.minimal, workers=args.workers)
|
| 583 |
-
save_audio_meta(args.output_meta_file, meta)
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
if __name__ == '__main__':
|
| 587 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/audio_utils.py
DELETED
|
@@ -1,296 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
"""Various utilities for audio convertion (pcm format, sample rate and channels),
|
| 7 |
-
and volume normalization."""
|
| 8 |
-
import sys
|
| 9 |
-
import typing as tp
|
| 10 |
-
|
| 11 |
-
import julius
|
| 12 |
-
import torch
|
| 13 |
-
import torchaudio
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
|
| 17 |
-
"""Convert audio to the given number of channels.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
wav (torch.Tensor): Audio wave of shape [B, C, T].
|
| 21 |
-
channels (int): Expected number of channels as output.
|
| 22 |
-
Returns:
|
| 23 |
-
torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
|
| 24 |
-
"""
|
| 25 |
-
*shape, src_channels, length = wav.shape
|
| 26 |
-
if src_channels == channels:
|
| 27 |
-
pass
|
| 28 |
-
elif channels == 1:
|
| 29 |
-
# Case 1:
|
| 30 |
-
# The caller asked 1-channel audio, and the stream has multiple
|
| 31 |
-
# channels, downmix all channels.
|
| 32 |
-
wav = wav.mean(dim=-2, keepdim=True)
|
| 33 |
-
elif src_channels == 1:
|
| 34 |
-
# Case 2:
|
| 35 |
-
# The caller asked for multiple channels, but the input file has
|
| 36 |
-
# a single channel, replicate the audio over all channels.
|
| 37 |
-
wav = wav.expand(*shape, channels, length)
|
| 38 |
-
elif src_channels >= channels:
|
| 39 |
-
# Case 3:
|
| 40 |
-
# The caller asked for multiple channels, and the input file has
|
| 41 |
-
# more channels than requested. In that case return the first channels.
|
| 42 |
-
wav = wav[..., :channels, :]
|
| 43 |
-
else:
|
| 44 |
-
# Case 4: What is a reasonable choice here?
|
| 45 |
-
raise ValueError('The audio file has less channels than requested but is not mono.')
|
| 46 |
-
return wav
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def convert_audio(wav: torch.Tensor, from_rate: float,
|
| 50 |
-
to_rate: float, to_channels: int) -> torch.Tensor:
|
| 51 |
-
"""Convert audio to new sample rate and number of audio channels.
|
| 52 |
-
"""
|
| 53 |
-
wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
|
| 54 |
-
wav = convert_audio_channels(wav, to_channels)
|
| 55 |
-
return wav
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
|
| 59 |
-
loudness_compressor: bool = False, energy_floor: float = 2e-3):
|
| 60 |
-
"""Normalize an input signal to a user loudness in dB LKFS.
|
| 61 |
-
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
wav (torch.Tensor): Input multichannel audio data.
|
| 65 |
-
sample_rate (int): Sample rate.
|
| 66 |
-
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
|
| 67 |
-
loudness_compressor (bool): Uses tanh for soft clipping.
|
| 68 |
-
energy_floor (float): anything below that RMS level will not be rescaled.
|
| 69 |
-
Returns:
|
| 70 |
-
output (torch.Tensor): Loudness normalized output data.
|
| 71 |
-
"""
|
| 72 |
-
energy = wav.pow(2).mean().sqrt().item()
|
| 73 |
-
if energy < energy_floor:
|
| 74 |
-
return wav
|
| 75 |
-
transform = torchaudio.transforms.Loudness(sample_rate)
|
| 76 |
-
input_loudness_db = transform(wav).item()
|
| 77 |
-
# calculate the gain needed to scale to the desired loudness level
|
| 78 |
-
delta_loudness = -loudness_headroom_db - input_loudness_db
|
| 79 |
-
gain = 10.0 ** (delta_loudness / 20.0)
|
| 80 |
-
output = gain * wav
|
| 81 |
-
if loudness_compressor:
|
| 82 |
-
output = torch.tanh(output)
|
| 83 |
-
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
|
| 84 |
-
return output
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
|
| 88 |
-
"""Utility function to clip the audio with logging if specified."""
|
| 89 |
-
max_scale = wav.abs().max()
|
| 90 |
-
if log_clipping and max_scale > 1:
|
| 91 |
-
clamp_prob = (wav.abs() > 1).float().mean().item()
|
| 92 |
-
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
|
| 93 |
-
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
|
| 94 |
-
wav.clamp_(-1, 1)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
|
| 98 |
-
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
|
| 99 |
-
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
|
| 100 |
-
loudness_compressor: bool = False, log_clipping: bool = False,
|
| 101 |
-
sample_rate: tp.Optional[int] = None,
|
| 102 |
-
stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
| 103 |
-
"""Normalize the audio according to the prescribed strategy (see after).
|
| 104 |
-
|
| 105 |
-
Args:
|
| 106 |
-
wav (torch.Tensor): Audio data.
|
| 107 |
-
normalize (bool): if `True` (default), normalizes according to the prescribed
|
| 108 |
-
strategy (see after). If `False`, the strategy is only used in case clipping
|
| 109 |
-
would happen.
|
| 110 |
-
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
|
| 111 |
-
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
|
| 112 |
-
with extra headroom to avoid clipping. 'clip' just clips.
|
| 113 |
-
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
|
| 114 |
-
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
|
| 115 |
-
than the `peak_clip` one to avoid further clipping.
|
| 116 |
-
loudness_headroom_db (float): Target loudness for loudness normalization.
|
| 117 |
-
loudness_compressor (bool): If True, uses tanh based soft clipping.
|
| 118 |
-
log_clipping (bool): If True, basic logging on stderr when clipping still
|
| 119 |
-
occurs despite strategy (only for 'rms').
|
| 120 |
-
sample_rate (int): Sample rate for the audio data (required for loudness).
|
| 121 |
-
stem_name (Optional[str]): Stem name for clipping logging.
|
| 122 |
-
Returns:
|
| 123 |
-
torch.Tensor: Normalized audio.
|
| 124 |
-
"""
|
| 125 |
-
scale_peak = 10 ** (-peak_clip_headroom_db / 20)
|
| 126 |
-
scale_rms = 10 ** (-rms_headroom_db / 20)
|
| 127 |
-
if strategy == 'peak':
|
| 128 |
-
rescaling = (scale_peak / wav.abs().max())
|
| 129 |
-
if normalize or rescaling < 1:
|
| 130 |
-
wav = wav * rescaling
|
| 131 |
-
elif strategy == 'clip':
|
| 132 |
-
wav = wav.clamp(-scale_peak, scale_peak)
|
| 133 |
-
elif strategy == 'rms':
|
| 134 |
-
mono = wav.mean(dim=0)
|
| 135 |
-
rescaling = scale_rms / mono.pow(2).mean().sqrt()
|
| 136 |
-
if normalize or rescaling < 1:
|
| 137 |
-
wav = wav * rescaling
|
| 138 |
-
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
| 139 |
-
elif strategy == 'loudness':
|
| 140 |
-
assert sample_rate is not None, "Loudness normalization requires sample rate."
|
| 141 |
-
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
|
| 142 |
-
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
|
| 143 |
-
else:
|
| 144 |
-
assert wav.abs().max() < 1
|
| 145 |
-
assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
|
| 146 |
-
return wav
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
| 150 |
-
"""Convert audio to float 32 bits PCM format.
|
| 151 |
-
"""
|
| 152 |
-
if wav.dtype.is_floating_point:
|
| 153 |
-
return wav
|
| 154 |
-
elif wav.dtype == torch.int16:
|
| 155 |
-
return wav.float() / 2**15
|
| 156 |
-
elif wav.dtype == torch.int32:
|
| 157 |
-
return wav.float() / 2**31
|
| 158 |
-
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
|
| 162 |
-
"""Convert audio to int 16 bits PCM format.
|
| 163 |
-
|
| 164 |
-
..Warning:: There exist many formula for doing this conversion. None are perfect
|
| 165 |
-
due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
|
| 166 |
-
or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
|
| 167 |
-
it is possible that `i16_pcm(f32_pcm)) != Identity`.
|
| 168 |
-
"""
|
| 169 |
-
if wav.dtype.is_floating_point:
|
| 170 |
-
assert wav.abs().max() <= 1
|
| 171 |
-
candidate = (wav * 2 ** 15).round()
|
| 172 |
-
if candidate.max() >= 2 ** 15: # clipping would occur
|
| 173 |
-
candidate = (wav * (2 ** 15 - 1)).round()
|
| 174 |
-
return candidate.short()
|
| 175 |
-
else:
|
| 176 |
-
assert wav.dtype == torch.int16
|
| 177 |
-
return wav
|
| 178 |
-
|
| 179 |
-
def apply_tafade(audio: torch.Tensor, sample_rate, duration=3.0, out=True, start=True, shape: str = "linear", stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
| 180 |
-
"""
|
| 181 |
-
Apply fade-in and/or fade-out effects to the audio tensor.
|
| 182 |
-
|
| 183 |
-
Args:
|
| 184 |
-
audio (torch.Tensor): The input audio tensor of shape (C, L).
|
| 185 |
-
sample_rate (int): The sample rate of the audio.
|
| 186 |
-
duration (float, optional): The duration of the fade in seconds. Defaults to 3.0.
|
| 187 |
-
out (bool, optional): Determines whether to apply fade-in (False) or fade-out (True) effect. Defaults to True.
|
| 188 |
-
start (bool, optional): Determines whether the fade is applied to the beginning (True) or end (False) of the audio. Defaults to True.
|
| 189 |
-
shape (str, optional): The shape of the fade. Must be one of: "quarter_sine", "half_sine", "linear", "logarithmic", "exponential". Defaults to "linear".
|
| 190 |
-
|
| 191 |
-
Returns:
|
| 192 |
-
torch.Tensor: The audio tensor with the fade effect applied.
|
| 193 |
-
|
| 194 |
-
"""
|
| 195 |
-
fade_samples = int(sample_rate * duration) # Number of samples for the fade duration
|
| 196 |
-
|
| 197 |
-
# Create the fade transform
|
| 198 |
-
fade_transform = torchaudio.transforms.Fade(fade_in_len=0, fade_out_len=0, fade_shape=shape)
|
| 199 |
-
|
| 200 |
-
if out:
|
| 201 |
-
fade_transform.fade_out_len = fade_samples
|
| 202 |
-
else:
|
| 203 |
-
fade_transform.fade_in_len = fade_samples
|
| 204 |
-
|
| 205 |
-
# Select the portion of the audio to apply the fade
|
| 206 |
-
if start:
|
| 207 |
-
audio_fade_section = audio[:, :fade_samples]
|
| 208 |
-
else:
|
| 209 |
-
audio_fade_section = audio[:, -fade_samples:]
|
| 210 |
-
|
| 211 |
-
# Apply the fade transform to the audio section
|
| 212 |
-
audio_faded = fade_transform(audio)
|
| 213 |
-
|
| 214 |
-
# Replace the selected portion of the audio with the faded section
|
| 215 |
-
if start:
|
| 216 |
-
audio_faded[:, :fade_samples] = audio_fade_section
|
| 217 |
-
else:
|
| 218 |
-
audio_faded[:, -fade_samples:] = audio_fade_section
|
| 219 |
-
|
| 220 |
-
wav = normalize_loudness(audio_faded,sample_rate, loudness_headroom_db=18, loudness_compressor=True)
|
| 221 |
-
_clip_wav(wav, log_clipping=False, stem_name=stem_name)
|
| 222 |
-
return wav
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
def apply_fade(audio: torch.Tensor, sample_rate, duration=3.0, out=True, start=True, curve_start:float=0.0, curve_end:float=1.0, current_device:str="cpu", stem_name: tp.Optional[str] = None) -> torch.Tensor:
|
| 226 |
-
"""
|
| 227 |
-
Apply fade-in and/or fade-out effects to the audio tensor.
|
| 228 |
-
|
| 229 |
-
Args:
|
| 230 |
-
audio (torch.Tensor): The input audio tensor of shape (C, L).
|
| 231 |
-
sample_rate (int): The sample rate of the audio.
|
| 232 |
-
duration (float, optional): The duration of the fade in seconds. Defaults to 3.0.
|
| 233 |
-
out (bool, optional): Determines whether to apply fade-in (False) or fade-out (True) effect. Defaults to True.
|
| 234 |
-
start (bool, optional): Determines whether the fade is applied to the beginning (True) or end (False) of the audio. Defaults to True.
|
| 235 |
-
curve_start (float, optional): The starting amplitude of the fade curve. Defaults to 0.0.
|
| 236 |
-
curve_end (float, optional): The ending amplitude of the fade curve. Defaults to 1.0.
|
| 237 |
-
current_device (str, optional): The device on which the fade curve tensor should be created. Defaults to "cpu".
|
| 238 |
-
|
| 239 |
-
Returns:
|
| 240 |
-
torch.Tensor: The audio tensor with the fade effect applied.
|
| 241 |
-
|
| 242 |
-
"""
|
| 243 |
-
fade_samples = int(sample_rate * duration) # Number of samples for the fade duration
|
| 244 |
-
fade_curve = torch.linspace(curve_start, curve_end, fade_samples, device=current_device) # Generate linear fade curve
|
| 245 |
-
|
| 246 |
-
if out:
|
| 247 |
-
fade_curve = fade_curve.flip(0) # Reverse the fade curve for fade out
|
| 248 |
-
|
| 249 |
-
# Select the portion of the audio to apply the fade
|
| 250 |
-
if start:
|
| 251 |
-
audio_fade_section = audio[:, :fade_samples]
|
| 252 |
-
else:
|
| 253 |
-
audio_fade_section = audio[:, -fade_samples:]
|
| 254 |
-
|
| 255 |
-
# Apply the fade curve to the audio section
|
| 256 |
-
audio_faded = audio.clone()
|
| 257 |
-
audio_faded[:, :fade_samples] *= fade_curve.unsqueeze(0)
|
| 258 |
-
audio_faded[:, -fade_samples:] *= fade_curve.unsqueeze(0)
|
| 259 |
-
|
| 260 |
-
# Replace the selected portion of the audio with the faded section
|
| 261 |
-
if start:
|
| 262 |
-
audio_faded[:, :fade_samples] = audio_fade_section
|
| 263 |
-
else:
|
| 264 |
-
audio_faded[:, -fade_samples:] = audio_fade_section
|
| 265 |
-
|
| 266 |
-
wav = normalize_loudness(audio_faded,sample_rate, loudness_headroom_db=18, loudness_compressor=True)
|
| 267 |
-
_clip_wav(wav, log_clipping=False, stem_name=stem_name)
|
| 268 |
-
return wav
|
| 269 |
-
|
| 270 |
-
def apply_splice_effect(waveform1, sample_rate1, waveform2, sample_rate2, overlap):
|
| 271 |
-
# Convert sample rates to integers
|
| 272 |
-
sample_rate1 = int(sample_rate1)
|
| 273 |
-
sample_rate2 = int(sample_rate2)
|
| 274 |
-
|
| 275 |
-
# Convert tensors to mono-channel if needed
|
| 276 |
-
if waveform1.ndim > 2:
|
| 277 |
-
waveform1 = waveform1.mean(dim=1)
|
| 278 |
-
if waveform2.ndim > 2:
|
| 279 |
-
waveform2 = waveform2.mean(dim=1)
|
| 280 |
-
|
| 281 |
-
## Convert tensors to numpy arrays
|
| 282 |
-
#waveform1_np = waveform1.numpy()
|
| 283 |
-
#waveform2_np = waveform2.numpy()
|
| 284 |
-
|
| 285 |
-
# Apply splice effect using torchaudio.sox_effects.apply_effects_tensor
|
| 286 |
-
effects = [
|
| 287 |
-
["splice", f"-q {waveform1},{overlap}"],
|
| 288 |
-
]
|
| 289 |
-
output_waveform, output_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
|
| 290 |
-
torch.cat([waveform1.unsqueeze(0), waveform2.unsqueeze(0)], dim=2),
|
| 291 |
-
sample_rate1,
|
| 292 |
-
effects
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
return output_waveform.squeeze(0), output_sample_rate
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/info_audio_dataset.py
DELETED
|
@@ -1,110 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
"""Base classes for the datasets that also provide non-audio metadata,
|
| 7 |
-
e.g. description, text transcription etc.
|
| 8 |
-
"""
|
| 9 |
-
from dataclasses import dataclass
|
| 10 |
-
import logging
|
| 11 |
-
import math
|
| 12 |
-
import re
|
| 13 |
-
import typing as tp
|
| 14 |
-
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
from .audio_dataset import AudioDataset, AudioMeta
|
| 18 |
-
from ..environment import AudioCraftEnvironment
|
| 19 |
-
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
logger = logging.getLogger(__name__)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
|
| 26 |
-
"""Monkey-patch meta to match cluster specificities."""
|
| 27 |
-
meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
|
| 28 |
-
if meta.info_path is not None:
|
| 29 |
-
meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
|
| 30 |
-
return meta
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
|
| 34 |
-
"""Monkey-patch all meta to match cluster specificities."""
|
| 35 |
-
return [_clusterify_meta(m) for m in meta]
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
@dataclass
|
| 39 |
-
class AudioInfo(SegmentWithAttributes):
|
| 40 |
-
"""Dummy SegmentInfo with empty attributes.
|
| 41 |
-
|
| 42 |
-
The InfoAudioDataset is expected to return metadata that inherits
|
| 43 |
-
from SegmentWithAttributes class and can return conditioning attributes.
|
| 44 |
-
|
| 45 |
-
This basically guarantees all datasets will be compatible with current
|
| 46 |
-
solver that contain conditioners requiring this.
|
| 47 |
-
"""
|
| 48 |
-
audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
|
| 49 |
-
|
| 50 |
-
def to_condition_attributes(self) -> ConditioningAttributes:
|
| 51 |
-
return ConditioningAttributes()
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class InfoAudioDataset(AudioDataset):
|
| 55 |
-
"""AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
|
| 56 |
-
|
| 57 |
-
See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
|
| 58 |
-
"""
|
| 59 |
-
def __init__(self, meta: tp.List[AudioMeta], **kwargs):
|
| 60 |
-
super().__init__(clusterify_all_meta(meta), **kwargs)
|
| 61 |
-
|
| 62 |
-
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
|
| 63 |
-
if not self.return_info:
|
| 64 |
-
wav = super().__getitem__(index)
|
| 65 |
-
assert isinstance(wav, torch.Tensor)
|
| 66 |
-
return wav
|
| 67 |
-
wav, meta = super().__getitem__(index)
|
| 68 |
-
return wav, AudioInfo(**meta.to_dict())
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
|
| 72 |
-
"""Preprocess a single keyword or possible a list of keywords."""
|
| 73 |
-
if isinstance(value, list):
|
| 74 |
-
return get_keyword_list(value)
|
| 75 |
-
else:
|
| 76 |
-
return get_keyword(value)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
|
| 80 |
-
"""Preprocess a single keyword."""
|
| 81 |
-
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
| 82 |
-
return None
|
| 83 |
-
else:
|
| 84 |
-
return value.strip()
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
|
| 88 |
-
"""Preprocess a single keyword."""
|
| 89 |
-
if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
|
| 90 |
-
return None
|
| 91 |
-
else:
|
| 92 |
-
return value.strip().lower()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
|
| 96 |
-
"""Preprocess a list of keywords."""
|
| 97 |
-
if isinstance(values, str):
|
| 98 |
-
values = [v.strip() for v in re.split(r'[,\s]', values)]
|
| 99 |
-
elif isinstance(values, float) and math.isnan(values):
|
| 100 |
-
values = []
|
| 101 |
-
if not isinstance(values, list):
|
| 102 |
-
logger.debug(f"Unexpected keyword list {values}")
|
| 103 |
-
values = [str(values)]
|
| 104 |
-
|
| 105 |
-
kws = [get_keyword(v) for v in values]
|
| 106 |
-
kw_list = [k for k in kws if k is not None]
|
| 107 |
-
if len(kw_list) == 0:
|
| 108 |
-
return None
|
| 109 |
-
else:
|
| 110 |
-
return kw_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/data/zip.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
"""Utility for reading some info from inside a zip file.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import typing
|
| 10 |
-
import zipfile
|
| 11 |
-
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from functools import lru_cache
|
| 14 |
-
from typing_extensions import Literal
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
DEFAULT_SIZE = 32
|
| 18 |
-
MODE = Literal['r', 'w', 'x', 'a']
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@dataclass(order=True)
|
| 22 |
-
class PathInZip:
|
| 23 |
-
"""Hold a path of file within a zip file.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
|
| 27 |
-
Let's assume there is a zip file /some/location/foo.zip
|
| 28 |
-
and inside of it is a json file located at /data/file1.json,
|
| 29 |
-
Then we expect path = "/some/location/foo.zip:/data/file1.json".
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
INFO_PATH_SEP = ':'
|
| 33 |
-
zip_path: str
|
| 34 |
-
file_path: str
|
| 35 |
-
|
| 36 |
-
def __init__(self, path: str) -> None:
|
| 37 |
-
split_path = path.split(self.INFO_PATH_SEP)
|
| 38 |
-
assert len(split_path) == 2
|
| 39 |
-
self.zip_path, self.file_path = split_path
|
| 40 |
-
|
| 41 |
-
@classmethod
|
| 42 |
-
def from_paths(cls, zip_path: str, file_path: str):
|
| 43 |
-
return cls(zip_path + cls.INFO_PATH_SEP + file_path)
|
| 44 |
-
|
| 45 |
-
def __str__(self) -> str:
|
| 46 |
-
return self.zip_path + self.INFO_PATH_SEP + self.file_path
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def _open_zip(path: str, mode: MODE = 'r'):
|
| 50 |
-
return zipfile.ZipFile(path, mode)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def set_zip_cache_size(max_size: int):
|
| 57 |
-
"""Sets the maximal LRU caching for zip file opening.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
max_size (int): the maximal LRU cache.
|
| 61 |
-
"""
|
| 62 |
-
global _cached_open_zip
|
| 63 |
-
_cached_open_zip = lru_cache(max_size)(_open_zip)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
|
| 67 |
-
"""Opens a file stored inside a zip and returns a file-like object.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
|
| 71 |
-
mode (str): The mode in which to open the file with.
|
| 72 |
-
Returns:
|
| 73 |
-
A file-like object for PathInZip.
|
| 74 |
-
"""
|
| 75 |
-
zf = _cached_open_zip(path_in_zip.zip_path)
|
| 76 |
-
return zf.open(path_in_zip.file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/environment.py
DELETED
|
@@ -1,176 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
import logging
|
| 12 |
-
import os
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import re
|
| 15 |
-
import typing as tp
|
| 16 |
-
|
| 17 |
-
import omegaconf
|
| 18 |
-
|
| 19 |
-
from .utils.cluster import _guess_cluster_type
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
logger = logging.getLogger(__name__)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class AudioCraftEnvironment:
|
| 26 |
-
"""Environment configuration for teams and clusters.
|
| 27 |
-
|
| 28 |
-
AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
|
| 29 |
-
or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
|
| 30 |
-
provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
|
| 31 |
-
allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
|
| 32 |
-
map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
|
| 33 |
-
|
| 34 |
-
The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
|
| 35 |
-
Use the following environment variables to specify the cluster, team or configuration:
|
| 36 |
-
|
| 37 |
-
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
|
| 38 |
-
cannot be inferred automatically.
|
| 39 |
-
AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
|
| 40 |
-
If not set, configuration is read from config/teams.yaml.
|
| 41 |
-
AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
|
| 42 |
-
Cluster configuration are shared across teams to match compute allocation,
|
| 43 |
-
specify your cluster configuration in the configuration file under a key mapping
|
| 44 |
-
your team name.
|
| 45 |
-
"""
|
| 46 |
-
_instance = None
|
| 47 |
-
DEFAULT_TEAM = "default"
|
| 48 |
-
|
| 49 |
-
def __init__(self) -> None:
|
| 50 |
-
"""Loads configuration."""
|
| 51 |
-
self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
|
| 52 |
-
cluster_type = _guess_cluster_type()
|
| 53 |
-
cluster = os.getenv(
|
| 54 |
-
"AUDIOCRAFT_CLUSTER", cluster_type.value
|
| 55 |
-
)
|
| 56 |
-
logger.info("Detecting cluster type %s", cluster_type)
|
| 57 |
-
|
| 58 |
-
self.cluster: str = cluster
|
| 59 |
-
|
| 60 |
-
config_path = os.getenv(
|
| 61 |
-
"AUDIOCRAFT_CONFIG",
|
| 62 |
-
Path(__file__)
|
| 63 |
-
.parent.parent.joinpath("config/teams", self.team)
|
| 64 |
-
.with_suffix(".yaml"),
|
| 65 |
-
)
|
| 66 |
-
self.config = omegaconf.OmegaConf.load(config_path)
|
| 67 |
-
self._dataset_mappers = []
|
| 68 |
-
cluster_config = self._get_cluster_config()
|
| 69 |
-
if "dataset_mappers" in cluster_config:
|
| 70 |
-
for pattern, repl in cluster_config["dataset_mappers"].items():
|
| 71 |
-
regex = re.compile(pattern)
|
| 72 |
-
self._dataset_mappers.append((regex, repl))
|
| 73 |
-
|
| 74 |
-
def _get_cluster_config(self) -> omegaconf.DictConfig:
|
| 75 |
-
assert isinstance(self.config, omegaconf.DictConfig)
|
| 76 |
-
return self.config[self.cluster]
|
| 77 |
-
|
| 78 |
-
@classmethod
|
| 79 |
-
def instance(cls):
|
| 80 |
-
if cls._instance is None:
|
| 81 |
-
cls._instance = cls()
|
| 82 |
-
return cls._instance
|
| 83 |
-
|
| 84 |
-
@classmethod
|
| 85 |
-
def reset(cls):
|
| 86 |
-
"""Clears the environment and forces a reload on next invocation."""
|
| 87 |
-
cls._instance = None
|
| 88 |
-
|
| 89 |
-
@classmethod
|
| 90 |
-
def get_team(cls) -> str:
|
| 91 |
-
"""Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
|
| 92 |
-
If not defined, defaults to "labs".
|
| 93 |
-
"""
|
| 94 |
-
return cls.instance().team
|
| 95 |
-
|
| 96 |
-
@classmethod
|
| 97 |
-
def get_cluster(cls) -> str:
|
| 98 |
-
"""Gets the detected cluster.
|
| 99 |
-
This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
|
| 100 |
-
"""
|
| 101 |
-
return cls.instance().cluster
|
| 102 |
-
|
| 103 |
-
@classmethod
|
| 104 |
-
def get_dora_dir(cls) -> Path:
|
| 105 |
-
"""Gets the path to the dora directory for the current team and cluster.
|
| 106 |
-
Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
|
| 107 |
-
"""
|
| 108 |
-
cluster_config = cls.instance()._get_cluster_config()
|
| 109 |
-
dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
|
| 110 |
-
logger.warning(f"Dora directory: {dora_dir}")
|
| 111 |
-
return Path(dora_dir)
|
| 112 |
-
|
| 113 |
-
@classmethod
|
| 114 |
-
def get_reference_dir(cls) -> Path:
|
| 115 |
-
"""Gets the path to the reference directory for the current team and cluster.
|
| 116 |
-
Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
|
| 117 |
-
"""
|
| 118 |
-
cluster_config = cls.instance()._get_cluster_config()
|
| 119 |
-
return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
|
| 120 |
-
|
| 121 |
-
@classmethod
|
| 122 |
-
def get_slurm_exclude(cls) -> tp.Optional[str]:
|
| 123 |
-
"""Get the list of nodes to exclude for that cluster."""
|
| 124 |
-
cluster_config = cls.instance()._get_cluster_config()
|
| 125 |
-
return cluster_config.get("slurm_exclude")
|
| 126 |
-
|
| 127 |
-
@classmethod
|
| 128 |
-
def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
|
| 129 |
-
"""Gets the requested partitions for the current team and cluster as a comma-separated string.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
partition_types (list[str], optional): partition types to retrieve. Values must be
|
| 133 |
-
from ['global', 'team']. If not provided, the global partition is returned.
|
| 134 |
-
"""
|
| 135 |
-
if not partition_types:
|
| 136 |
-
partition_types = ["global"]
|
| 137 |
-
|
| 138 |
-
cluster_config = cls.instance()._get_cluster_config()
|
| 139 |
-
partitions = [
|
| 140 |
-
cluster_config["partitions"][partition_type]
|
| 141 |
-
for partition_type in partition_types
|
| 142 |
-
]
|
| 143 |
-
return ",".join(partitions)
|
| 144 |
-
|
| 145 |
-
@classmethod
|
| 146 |
-
def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
|
| 147 |
-
"""Converts reference placeholder in path with configured reference dir to resolve paths.
|
| 148 |
-
|
| 149 |
-
Args:
|
| 150 |
-
path (str or Path): Path to resolve.
|
| 151 |
-
Returns:
|
| 152 |
-
Path: Resolved path.
|
| 153 |
-
"""
|
| 154 |
-
path = str(path)
|
| 155 |
-
|
| 156 |
-
if path.startswith("//reference"):
|
| 157 |
-
reference_dir = cls.get_reference_dir()
|
| 158 |
-
logger.warn(f"Reference directory: {reference_dir}")
|
| 159 |
-
assert (
|
| 160 |
-
reference_dir.exists() and reference_dir.is_dir()
|
| 161 |
-
), f"Reference directory does not exist: {reference_dir}."
|
| 162 |
-
path = re.sub("^//reference", str(reference_dir), path)
|
| 163 |
-
|
| 164 |
-
return Path(path)
|
| 165 |
-
|
| 166 |
-
@classmethod
|
| 167 |
-
def apply_dataset_mappers(cls, path: str) -> str:
|
| 168 |
-
"""Applies dataset mapping regex rules as defined in the configuration.
|
| 169 |
-
If no rules are defined, the path is returned as-is.
|
| 170 |
-
"""
|
| 171 |
-
instance = cls.instance()
|
| 172 |
-
|
| 173 |
-
for pattern, repl in instance._dataset_mappers:
|
| 174 |
-
path = pattern.sub(repl, path)
|
| 175 |
-
|
| 176 |
-
return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/__init__.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
|
| 9 |
-
"""
|
| 10 |
-
# flake8: noqa
|
| 11 |
-
from . import builders, loaders
|
| 12 |
-
from .encodec import (
|
| 13 |
-
CompressionModel, EncodecModel, DAC,
|
| 14 |
-
HFEncodecModel, HFEncodecCompressionModel)
|
| 15 |
-
from .lm import LMModel
|
| 16 |
-
from .lm_magnet import MagnetLMModel
|
| 17 |
-
from .flow_matching import FlowMatchingModel
|
| 18 |
-
from .encodec import CompressionModel, EncodecModel
|
| 19 |
-
from .musicgen import MusicGen
|
| 20 |
-
from .magnet import MAGNeT
|
| 21 |
-
from .unet import DiffusionUnet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/builders.py
DELETED
|
@@ -1,351 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
All the functions to build the relevant models and modules
|
| 9 |
-
from the Hydra config.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import typing as tp
|
| 13 |
-
|
| 14 |
-
import omegaconf
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
import audiocraft
|
| 18 |
-
|
| 19 |
-
from .. import quantization as qt
|
| 20 |
-
from ..modules.codebooks_patterns import (CoarseFirstPattern,
|
| 21 |
-
CodebooksPatternProvider,
|
| 22 |
-
DelayedPatternProvider,
|
| 23 |
-
MusicLMPattern,
|
| 24 |
-
ParallelPatternProvider,
|
| 25 |
-
UnrolledPatternProvider)
|
| 26 |
-
from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner,
|
| 27 |
-
CLAPEmbeddingConditioner,
|
| 28 |
-
ConditionFuser, JascoCondConst,
|
| 29 |
-
ConditioningProvider, LUTConditioner,
|
| 30 |
-
T5Conditioner, StyleConditioner)
|
| 31 |
-
from ..modules.jasco_conditioners import (JascoConditioningProvider, ChordsEmbConditioner,
|
| 32 |
-
DrumsConditioner, MelodyConditioner)
|
| 33 |
-
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
|
| 34 |
-
from ..utils.utils import dict_from_config
|
| 35 |
-
from .encodec import (CompressionModel, EncodecModel,
|
| 36 |
-
InterleaveStereoCompressionModel)
|
| 37 |
-
from .lm import LMModel
|
| 38 |
-
from .lm_magnet import MagnetLMModel
|
| 39 |
-
from .flow_matching import FlowMatchingModel
|
| 40 |
-
from .unet import DiffusionUnet
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def get_quantizer(
|
| 45 |
-
quantizer: str, cfg: omegaconf.DictConfig, dimension: int
|
| 46 |
-
) -> qt.BaseQuantizer:
|
| 47 |
-
klass = {"no_quant": qt.DummyQuantizer, "rvq": qt.ResidualVectorQuantizer}[
|
| 48 |
-
quantizer
|
| 49 |
-
]
|
| 50 |
-
kwargs = dict_from_config(getattr(cfg, quantizer))
|
| 51 |
-
if quantizer != "no_quant":
|
| 52 |
-
kwargs["dimension"] = dimension
|
| 53 |
-
return klass(**kwargs)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
|
| 57 |
-
if encoder_name == "seanet":
|
| 58 |
-
kwargs = dict_from_config(getattr(cfg, "seanet"))
|
| 59 |
-
encoder_override_kwargs = kwargs.pop("encoder")
|
| 60 |
-
decoder_override_kwargs = kwargs.pop("decoder")
|
| 61 |
-
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
| 62 |
-
decoder_kwargs = {**kwargs, **decoder_override_kwargs}
|
| 63 |
-
encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
|
| 64 |
-
decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
|
| 65 |
-
return encoder, decoder
|
| 66 |
-
else:
|
| 67 |
-
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
|
| 71 |
-
"""Instantiate a compression model."""
|
| 72 |
-
if cfg.compression_model == "encodec":
|
| 73 |
-
kwargs = dict_from_config(getattr(cfg, "encodec"))
|
| 74 |
-
encoder_name = kwargs.pop("autoencoder")
|
| 75 |
-
quantizer_name = kwargs.pop("quantizer")
|
| 76 |
-
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
|
| 77 |
-
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
|
| 78 |
-
frame_rate = kwargs["sample_rate"] // encoder.hop_length
|
| 79 |
-
renormalize = kwargs.pop("renormalize", False)
|
| 80 |
-
# deprecated params
|
| 81 |
-
kwargs.pop("renorm", None)
|
| 82 |
-
return EncodecModel(
|
| 83 |
-
encoder,
|
| 84 |
-
decoder,
|
| 85 |
-
quantizer,
|
| 86 |
-
frame_rate=frame_rate,
|
| 87 |
-
renormalize=renormalize,
|
| 88 |
-
**kwargs,
|
| 89 |
-
).to(cfg.device)
|
| 90 |
-
else:
|
| 91 |
-
raise KeyError(f"Unexpected compression model {cfg.compression_model}")
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def get_jasco_model(cfg: omegaconf.DictConfig,
|
| 95 |
-
compression_model: tp.Optional[CompressionModel] = None) -> FlowMatchingModel:
|
| 96 |
-
kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
|
| 97 |
-
attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
|
| 98 |
-
cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
|
| 99 |
-
cfg_prob = cls_free_guidance["training_dropout"]
|
| 100 |
-
cfg_coef = cls_free_guidance["inference_coef"]
|
| 101 |
-
fuser = get_condition_fuser(cfg)
|
| 102 |
-
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
| 103 |
-
if JascoCondConst.DRM.value in condition_provider.conditioners: # use self_wav for drums
|
| 104 |
-
assert compression_model is not None
|
| 105 |
-
|
| 106 |
-
# use compression model for drums conditioning
|
| 107 |
-
condition_provider.conditioners.self_wav.compression_model = compression_model
|
| 108 |
-
condition_provider.conditioners.self_wav.compression_model.requires_grad_(False)
|
| 109 |
-
|
| 110 |
-
# downcast to jasco conditioning provider
|
| 111 |
-
seq_len = cfg.compression_model_framerate * cfg.dataset.segment_duration
|
| 112 |
-
chords_card = cfg.conditioners.chords.chords_emb.card if JascoCondConst.CRD.value in cfg.conditioners else -1
|
| 113 |
-
condition_provider = JascoConditioningProvider(device=condition_provider.device,
|
| 114 |
-
conditioners=condition_provider.conditioners,
|
| 115 |
-
chords_card=chords_card,
|
| 116 |
-
sequence_length=seq_len)
|
| 117 |
-
|
| 118 |
-
if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
|
| 119 |
-
kwargs["cross_attention"] = True
|
| 120 |
-
|
| 121 |
-
kwargs.pop("n_q", None)
|
| 122 |
-
kwargs.pop("card", None)
|
| 123 |
-
|
| 124 |
-
return FlowMatchingModel(
|
| 125 |
-
condition_provider=condition_provider,
|
| 126 |
-
fuser=fuser,
|
| 127 |
-
cfg_dropout=cfg_prob,
|
| 128 |
-
cfg_coef=cfg_coef,
|
| 129 |
-
attribute_dropout=attribute_dropout,
|
| 130 |
-
dtype=getattr(torch, cfg.dtype),
|
| 131 |
-
device=cfg.device,
|
| 132 |
-
**kwargs,
|
| 133 |
-
).to(cfg.device)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
|
| 137 |
-
"""Instantiate a transformer LM."""
|
| 138 |
-
if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]:
|
| 139 |
-
kwargs = dict_from_config(getattr(cfg, "transformer_lm"))
|
| 140 |
-
n_q = kwargs["n_q"]
|
| 141 |
-
q_modeling = kwargs.pop("q_modeling", None)
|
| 142 |
-
codebooks_pattern_cfg = getattr(cfg, "codebooks_pattern")
|
| 143 |
-
attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout"))
|
| 144 |
-
cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance"))
|
| 145 |
-
cfg_prob, cfg_coef = (
|
| 146 |
-
cls_free_guidance["training_dropout"],
|
| 147 |
-
cls_free_guidance["inference_coef"],
|
| 148 |
-
)
|
| 149 |
-
fuser = get_condition_fuser(cfg)
|
| 150 |
-
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
|
| 151 |
-
if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically
|
| 152 |
-
kwargs["cross_attention"] = True
|
| 153 |
-
if codebooks_pattern_cfg.modeling is None:
|
| 154 |
-
assert (
|
| 155 |
-
q_modeling is not None
|
| 156 |
-
), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
|
| 157 |
-
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
|
| 158 |
-
{"modeling": q_modeling, "delay": {"delays": list(range(n_q))}}
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
|
| 162 |
-
lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel
|
| 163 |
-
return lm_class(
|
| 164 |
-
pattern_provider=pattern_provider,
|
| 165 |
-
condition_provider=condition_provider,
|
| 166 |
-
fuser=fuser,
|
| 167 |
-
cfg_dropout=cfg_prob,
|
| 168 |
-
cfg_coef=cfg_coef,
|
| 169 |
-
attribute_dropout=attribute_dropout,
|
| 170 |
-
dtype=getattr(torch, cfg.dtype),
|
| 171 |
-
device=cfg.device,
|
| 172 |
-
**kwargs,
|
| 173 |
-
).to(cfg.device)
|
| 174 |
-
else:
|
| 175 |
-
raise KeyError(f"Unexpected LM model {cfg.lm_model}")
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def get_conditioner_provider(
|
| 179 |
-
output_dim: int, cfg: omegaconf.DictConfig
|
| 180 |
-
) -> ConditioningProvider:
|
| 181 |
-
"""Instantiate a conditioning model."""
|
| 182 |
-
device = cfg.device
|
| 183 |
-
duration = cfg.dataset.segment_duration
|
| 184 |
-
cfg = getattr(cfg, "conditioners")
|
| 185 |
-
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
| 186 |
-
conditioners: tp.Dict[str, BaseConditioner] = {}
|
| 187 |
-
condition_provider_args = dict_cfg.pop("args", {})
|
| 188 |
-
condition_provider_args.pop("merge_text_conditions_p", None)
|
| 189 |
-
condition_provider_args.pop("drop_desc_p", None)
|
| 190 |
-
|
| 191 |
-
for cond, cond_cfg in dict_cfg.items():
|
| 192 |
-
model_type = cond_cfg["model"]
|
| 193 |
-
model_args = cond_cfg[model_type]
|
| 194 |
-
if model_type == "t5":
|
| 195 |
-
conditioners[str(cond)] = T5Conditioner(
|
| 196 |
-
output_dim=output_dim, device=device, **model_args
|
| 197 |
-
)
|
| 198 |
-
elif model_type == "lut":
|
| 199 |
-
conditioners[str(cond)] = LUTConditioner(
|
| 200 |
-
output_dim=output_dim, **model_args
|
| 201 |
-
)
|
| 202 |
-
elif model_type == "chroma_stem":
|
| 203 |
-
conditioners[str(cond)] = ChromaStemConditioner(
|
| 204 |
-
output_dim=output_dim, duration=duration, device=device, **model_args
|
| 205 |
-
)
|
| 206 |
-
elif model_type in {"chords_emb", "drum_latents", "melody"}:
|
| 207 |
-
conditioners_classes = {"chords_emb": ChordsEmbConditioner,
|
| 208 |
-
"drum_latents": DrumsConditioner,
|
| 209 |
-
"melody": MelodyConditioner}
|
| 210 |
-
conditioner_class = conditioners_classes[model_type]
|
| 211 |
-
conditioners[str(cond)] = conditioner_class(device=device, **model_args)
|
| 212 |
-
elif model_type == "clap":
|
| 213 |
-
conditioners[str(cond)] = CLAPEmbeddingConditioner(
|
| 214 |
-
output_dim=output_dim, device=device, **model_args
|
| 215 |
-
)
|
| 216 |
-
elif model_type == 'style':
|
| 217 |
-
conditioners[str(cond)] = StyleConditioner(
|
| 218 |
-
output_dim=output_dim,
|
| 219 |
-
device=device,
|
| 220 |
-
**model_args
|
| 221 |
-
)
|
| 222 |
-
else:
|
| 223 |
-
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
| 224 |
-
conditioner = ConditioningProvider(
|
| 225 |
-
conditioners, device=device, **condition_provider_args
|
| 226 |
-
)
|
| 227 |
-
return conditioner
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
| 231 |
-
"""Instantiate a condition fuser object."""
|
| 232 |
-
fuser_cfg = getattr(cfg, "fuser")
|
| 233 |
-
fuser_methods = ["sum", "cross", "prepend", "ignore", "input_interpolate"]
|
| 234 |
-
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods if k in fuser_cfg}
|
| 235 |
-
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
| 236 |
-
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
| 237 |
-
return fuser
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def get_codebooks_pattern_provider(
|
| 241 |
-
n_q: int, cfg: omegaconf.DictConfig
|
| 242 |
-
) -> CodebooksPatternProvider:
|
| 243 |
-
"""Instantiate a codebooks pattern provider object."""
|
| 244 |
-
pattern_providers = {
|
| 245 |
-
"parallel": ParallelPatternProvider,
|
| 246 |
-
"delay": DelayedPatternProvider,
|
| 247 |
-
"unroll": UnrolledPatternProvider,
|
| 248 |
-
"coarse_first": CoarseFirstPattern,
|
| 249 |
-
"musiclm": MusicLMPattern,
|
| 250 |
-
}
|
| 251 |
-
name = cfg.modeling
|
| 252 |
-
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
| 253 |
-
klass = pattern_providers[name]
|
| 254 |
-
return klass(n_q, **kwargs)
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def get_debug_compression_model(device="cpu", sample_rate: int = 32000):
|
| 258 |
-
"""Instantiate a debug compression model to be used for unit tests."""
|
| 259 |
-
assert sample_rate in [
|
| 260 |
-
16000,
|
| 261 |
-
32000,
|
| 262 |
-
], "unsupported sample rate for debug compression model"
|
| 263 |
-
model_ratios = {
|
| 264 |
-
16000: [10, 8, 8], # 25 Hz at 16kHz
|
| 265 |
-
32000: [10, 8, 16], # 25 Hz at 32kHz
|
| 266 |
-
}
|
| 267 |
-
ratios: tp.List[int] = model_ratios[sample_rate]
|
| 268 |
-
frame_rate = 25
|
| 269 |
-
seanet_kwargs: dict = {
|
| 270 |
-
"n_filters": 4,
|
| 271 |
-
"n_residual_layers": 1,
|
| 272 |
-
"dimension": 32,
|
| 273 |
-
"ratios": ratios,
|
| 274 |
-
}
|
| 275 |
-
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
|
| 276 |
-
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
|
| 277 |
-
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
|
| 278 |
-
init_x = torch.randn(8, 32, 128)
|
| 279 |
-
quantizer(init_x, 1) # initialize kmeans etc.
|
| 280 |
-
compression_model = EncodecModel(
|
| 281 |
-
encoder,
|
| 282 |
-
decoder,
|
| 283 |
-
quantizer,
|
| 284 |
-
frame_rate=frame_rate,
|
| 285 |
-
sample_rate=sample_rate,
|
| 286 |
-
channels=1,
|
| 287 |
-
).to(device)
|
| 288 |
-
return compression_model.eval()
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
def get_diffusion_model(cfg: omegaconf.DictConfig):
|
| 292 |
-
# TODO Find a way to infer the channels from dset
|
| 293 |
-
channels = cfg.channels
|
| 294 |
-
num_steps = cfg.schedule.num_steps
|
| 295 |
-
return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
def get_processor(cfg, sample_rate: int = 24000):
|
| 299 |
-
sample_processor = SampleProcessor()
|
| 300 |
-
if cfg.use:
|
| 301 |
-
kw = dict(cfg)
|
| 302 |
-
kw.pop("use")
|
| 303 |
-
kw.pop("name")
|
| 304 |
-
if cfg.name == "multi_band_processor":
|
| 305 |
-
sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
|
| 306 |
-
return sample_processor
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def get_debug_lm_model(device="cpu"):
|
| 310 |
-
"""Instantiate a debug LM to be used for unit tests."""
|
| 311 |
-
pattern = DelayedPatternProvider(n_q=4)
|
| 312 |
-
dim = 16
|
| 313 |
-
providers = {
|
| 314 |
-
"description": LUTConditioner(
|
| 315 |
-
n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"
|
| 316 |
-
),
|
| 317 |
-
}
|
| 318 |
-
condition_provider = ConditioningProvider(providers)
|
| 319 |
-
fuser = ConditionFuser(
|
| 320 |
-
{"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []}
|
| 321 |
-
)
|
| 322 |
-
lm = LMModel(
|
| 323 |
-
pattern,
|
| 324 |
-
condition_provider,
|
| 325 |
-
fuser,
|
| 326 |
-
n_q=4,
|
| 327 |
-
card=400,
|
| 328 |
-
dim=dim,
|
| 329 |
-
num_heads=4,
|
| 330 |
-
custom=True,
|
| 331 |
-
num_layers=2,
|
| 332 |
-
cross_attention=True,
|
| 333 |
-
causal=True,
|
| 334 |
-
)
|
| 335 |
-
return lm.to(device).eval()
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
def get_wrapped_compression_model(
|
| 339 |
-
compression_model: CompressionModel, cfg: omegaconf.DictConfig
|
| 340 |
-
) -> CompressionModel:
|
| 341 |
-
if hasattr(cfg, "interleave_stereo_codebooks"):
|
| 342 |
-
if cfg.interleave_stereo_codebooks.use:
|
| 343 |
-
kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
|
| 344 |
-
kwargs.pop("use")
|
| 345 |
-
compression_model = InterleaveStereoCompressionModel(
|
| 346 |
-
compression_model, **kwargs
|
| 347 |
-
)
|
| 348 |
-
if hasattr(cfg, "compression_model_n_q"):
|
| 349 |
-
if cfg.compression_model_n_q is not None:
|
| 350 |
-
compression_model.set_num_codebooks(cfg.compression_model_n_q)
|
| 351 |
-
return compression_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/encodec.py
DELETED
|
@@ -1,506 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
"""Compression models or wrapper around existing models.
|
| 7 |
-
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
from abc import ABC, abstractmethod
|
| 11 |
-
import logging
|
| 12 |
-
import math
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import typing as tp
|
| 15 |
-
|
| 16 |
-
from einops import rearrange
|
| 17 |
-
import numpy as np
|
| 18 |
-
import torch
|
| 19 |
-
from torch import nn
|
| 20 |
-
from transformers import EncodecModel as HFEncodecModel
|
| 21 |
-
|
| 22 |
-
from .. import quantization as qt
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
logger = logging.getLogger()
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class CompressionModel(ABC, nn.Module):
|
| 29 |
-
"""Base API for all compression model that aim at being used as audio tokenizers
|
| 30 |
-
with a language model.
|
| 31 |
-
"""
|
| 32 |
-
|
| 33 |
-
@abstractmethod
|
| 34 |
-
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
| 35 |
-
...
|
| 36 |
-
|
| 37 |
-
@abstractmethod
|
| 38 |
-
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 39 |
-
"""See `EncodecModel.encode`."""
|
| 40 |
-
...
|
| 41 |
-
|
| 42 |
-
@abstractmethod
|
| 43 |
-
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
| 44 |
-
"""See `EncodecModel.decode`."""
|
| 45 |
-
...
|
| 46 |
-
|
| 47 |
-
@abstractmethod
|
| 48 |
-
def decode_latent(self, codes: torch.Tensor):
|
| 49 |
-
"""Decode from the discrete codes to continuous latent space."""
|
| 50 |
-
...
|
| 51 |
-
|
| 52 |
-
@property
|
| 53 |
-
@abstractmethod
|
| 54 |
-
def channels(self) -> int:
|
| 55 |
-
...
|
| 56 |
-
|
| 57 |
-
@property
|
| 58 |
-
@abstractmethod
|
| 59 |
-
def frame_rate(self) -> float:
|
| 60 |
-
...
|
| 61 |
-
|
| 62 |
-
@property
|
| 63 |
-
@abstractmethod
|
| 64 |
-
def sample_rate(self) -> int:
|
| 65 |
-
...
|
| 66 |
-
|
| 67 |
-
@property
|
| 68 |
-
@abstractmethod
|
| 69 |
-
def cardinality(self) -> int:
|
| 70 |
-
...
|
| 71 |
-
|
| 72 |
-
@property
|
| 73 |
-
@abstractmethod
|
| 74 |
-
def num_codebooks(self) -> int:
|
| 75 |
-
...
|
| 76 |
-
|
| 77 |
-
@property
|
| 78 |
-
@abstractmethod
|
| 79 |
-
def total_codebooks(self) -> int:
|
| 80 |
-
...
|
| 81 |
-
|
| 82 |
-
@abstractmethod
|
| 83 |
-
def set_num_codebooks(self, n: int):
|
| 84 |
-
"""Set the active number of codebooks used by the quantizer."""
|
| 85 |
-
...
|
| 86 |
-
|
| 87 |
-
@staticmethod
|
| 88 |
-
def get_pretrained(
|
| 89 |
-
name: str, device: tp.Union[torch.device, str] = 'cpu'
|
| 90 |
-
) -> 'CompressionModel':
|
| 91 |
-
"""Instantiate a CompressionModel from a given pretrained model.
|
| 92 |
-
|
| 93 |
-
Args:
|
| 94 |
-
name (Path or str): name of the pretrained model. See after.
|
| 95 |
-
device (torch.device or str): Device on which the model is loaded.
|
| 96 |
-
|
| 97 |
-
Pretrained models:
|
| 98 |
-
- dac_44khz (https://github.com/descriptinc/descript-audio-codec)
|
| 99 |
-
- dac_24khz (same)
|
| 100 |
-
- facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
|
| 101 |
-
- facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
|
| 102 |
-
- your own model on HugginFace. Export instructions to come...
|
| 103 |
-
"""
|
| 104 |
-
|
| 105 |
-
from . import builders, loaders
|
| 106 |
-
model: CompressionModel
|
| 107 |
-
if name in ['dac_44khz', 'dac_24khz']:
|
| 108 |
-
model_type = name.split('_')[1]
|
| 109 |
-
logger.info("Getting pretrained compression model from DAC %s", model_type)
|
| 110 |
-
model = DAC(model_type)
|
| 111 |
-
elif name in ['debug_compression_model']:
|
| 112 |
-
logger.info("Getting pretrained compression model for debug")
|
| 113 |
-
model = builders.get_debug_compression_model()
|
| 114 |
-
elif Path(name).exists():
|
| 115 |
-
# We assume here if the paths exist that it is in fact an AC checkpoint
|
| 116 |
-
# that was exported using `audiocraft.utils.export` functions.
|
| 117 |
-
model = loaders.load_compression_model(name, device=device)
|
| 118 |
-
else:
|
| 119 |
-
logger.info("Getting pretrained compression model from HF %s", name)
|
| 120 |
-
hf_model = HFEncodecModel.from_pretrained(name)
|
| 121 |
-
model = HFEncodecCompressionModel(hf_model).to(device)
|
| 122 |
-
return model.to(device).eval()
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class EncodecModel(CompressionModel):
|
| 126 |
-
"""Encodec model operating on the raw waveform.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
encoder (nn.Module): Encoder network.
|
| 130 |
-
decoder (nn.Module): Decoder network.
|
| 131 |
-
quantizer (qt.BaseQuantizer): Quantizer network.
|
| 132 |
-
frame_rate (int): Frame rate for the latent representation.
|
| 133 |
-
sample_rate (int): Audio sample rate.
|
| 134 |
-
channels (int): Number of audio channels.
|
| 135 |
-
causal (bool): Whether to use a causal version of the model.
|
| 136 |
-
renormalize (bool): Whether to renormalize the audio before running the model.
|
| 137 |
-
"""
|
| 138 |
-
# we need assignment to override the property in the abstract class,
|
| 139 |
-
# I couldn't find a better way...
|
| 140 |
-
frame_rate: float = 0
|
| 141 |
-
sample_rate: int = 0
|
| 142 |
-
channels: int = 0
|
| 143 |
-
|
| 144 |
-
def __init__(self,
|
| 145 |
-
encoder: nn.Module,
|
| 146 |
-
decoder: nn.Module,
|
| 147 |
-
quantizer: qt.BaseQuantizer,
|
| 148 |
-
frame_rate: int,
|
| 149 |
-
sample_rate: int,
|
| 150 |
-
channels: int,
|
| 151 |
-
causal: bool = False,
|
| 152 |
-
renormalize: bool = False):
|
| 153 |
-
super().__init__()
|
| 154 |
-
self.encoder = encoder
|
| 155 |
-
self.decoder = decoder
|
| 156 |
-
self.quantizer = quantizer
|
| 157 |
-
self.frame_rate = frame_rate
|
| 158 |
-
self.sample_rate = sample_rate
|
| 159 |
-
self.channels = channels
|
| 160 |
-
self.renormalize = renormalize
|
| 161 |
-
self.causal = causal
|
| 162 |
-
if self.causal:
|
| 163 |
-
# we force disabling here to avoid handling linear overlap of segments
|
| 164 |
-
# as supported in original EnCodec codebase.
|
| 165 |
-
assert not self.renormalize, 'Causal model does not support renormalize'
|
| 166 |
-
|
| 167 |
-
@property
|
| 168 |
-
def total_codebooks(self):
|
| 169 |
-
"""Total number of quantizer codebooks available."""
|
| 170 |
-
return self.quantizer.total_codebooks
|
| 171 |
-
|
| 172 |
-
@property
|
| 173 |
-
def num_codebooks(self):
|
| 174 |
-
"""Active number of codebooks used by the quantizer."""
|
| 175 |
-
return self.quantizer.num_codebooks
|
| 176 |
-
|
| 177 |
-
def set_num_codebooks(self, n: int):
|
| 178 |
-
"""Set the active number of codebooks used by the quantizer."""
|
| 179 |
-
self.quantizer.set_num_codebooks(n)
|
| 180 |
-
|
| 181 |
-
@property
|
| 182 |
-
def cardinality(self):
|
| 183 |
-
"""Cardinality of each codebook."""
|
| 184 |
-
return self.quantizer.bins
|
| 185 |
-
|
| 186 |
-
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 187 |
-
scale: tp.Optional[torch.Tensor]
|
| 188 |
-
if self.renormalize:
|
| 189 |
-
mono = x.mean(dim=1, keepdim=True)
|
| 190 |
-
volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
|
| 191 |
-
scale = 1e-8 + volume
|
| 192 |
-
x = x / scale
|
| 193 |
-
scale = scale.view(-1, 1)
|
| 194 |
-
else:
|
| 195 |
-
scale = None
|
| 196 |
-
return x, scale
|
| 197 |
-
|
| 198 |
-
def postprocess(self,
|
| 199 |
-
x: torch.Tensor,
|
| 200 |
-
scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 201 |
-
if scale is not None:
|
| 202 |
-
assert self.renormalize
|
| 203 |
-
x = x * scale.view(-1, 1, 1)
|
| 204 |
-
return x
|
| 205 |
-
|
| 206 |
-
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
| 207 |
-
assert x.dim() == 3
|
| 208 |
-
length = x.shape[-1]
|
| 209 |
-
x, scale = self.preprocess(x)
|
| 210 |
-
|
| 211 |
-
emb = self.encoder(x)
|
| 212 |
-
q_res = self.quantizer(emb, self.frame_rate)
|
| 213 |
-
out = self.decoder(q_res.x)
|
| 214 |
-
|
| 215 |
-
# remove extra padding added by the encoder and decoder
|
| 216 |
-
assert out.shape[-1] >= length, (out.shape[-1], length)
|
| 217 |
-
out = out[..., :length]
|
| 218 |
-
|
| 219 |
-
q_res.x = self.postprocess(out, scale)
|
| 220 |
-
|
| 221 |
-
return q_res
|
| 222 |
-
|
| 223 |
-
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 224 |
-
"""Encode the given input tensor to quantized representation along with scale parameter.
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
x (torch.Tensor): Float tensor of shape [B, C, T]
|
| 228 |
-
|
| 229 |
-
Returns:
|
| 230 |
-
codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
|
| 231 |
-
codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
|
| 232 |
-
scale a float tensor containing the scale for audio renormalizealization.
|
| 233 |
-
"""
|
| 234 |
-
assert x.dim() == 3
|
| 235 |
-
x, scale = self.preprocess(x)
|
| 236 |
-
emb = self.encoder(x)
|
| 237 |
-
codes = self.quantizer.encode(emb)
|
| 238 |
-
return codes, scale
|
| 239 |
-
|
| 240 |
-
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
| 241 |
-
"""Decode the given codes to a reconstructed representation, using the scale to perform
|
| 242 |
-
audio denormalization if needed.
|
| 243 |
-
|
| 244 |
-
Args:
|
| 245 |
-
codes (torch.Tensor): Int tensor of shape [B, K, T]
|
| 246 |
-
scale (torch.Tensor, optional): Float tensor containing the scale value.
|
| 247 |
-
|
| 248 |
-
Returns:
|
| 249 |
-
out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
|
| 250 |
-
"""
|
| 251 |
-
emb = self.decode_latent(codes)
|
| 252 |
-
out = self.decoder(emb)
|
| 253 |
-
out = self.postprocess(out, scale)
|
| 254 |
-
# out contains extra padding added by the encoder and decoder
|
| 255 |
-
return out
|
| 256 |
-
|
| 257 |
-
def decode_latent(self, codes: torch.Tensor):
|
| 258 |
-
"""Decode from the discrete codes to continuous latent space."""
|
| 259 |
-
return self.quantizer.decode(codes)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
class DAC(CompressionModel):
|
| 263 |
-
def __init__(self, model_type: str = "44khz"):
|
| 264 |
-
super().__init__()
|
| 265 |
-
try:
|
| 266 |
-
import dac.utils
|
| 267 |
-
except ImportError:
|
| 268 |
-
raise RuntimeError("Could not import dac, make sure it is installed, "
|
| 269 |
-
"please run `pip install descript-audio-codec`")
|
| 270 |
-
self.model = dac.utils.load_model(model_type=model_type)
|
| 271 |
-
self.n_quantizers = self.total_codebooks
|
| 272 |
-
self.model.eval()
|
| 273 |
-
|
| 274 |
-
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
| 275 |
-
# We don't support training with this.
|
| 276 |
-
raise NotImplementedError("Forward and training with DAC not supported.")
|
| 277 |
-
|
| 278 |
-
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 279 |
-
codes = self.model.encode(x, self.n_quantizers)[1]
|
| 280 |
-
return codes[:, :self.n_quantizers], None
|
| 281 |
-
|
| 282 |
-
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
| 283 |
-
assert scale is None
|
| 284 |
-
z_q = self.decode_latent(codes)
|
| 285 |
-
return self.model.decode(z_q)
|
| 286 |
-
|
| 287 |
-
def decode_latent(self, codes: torch.Tensor):
|
| 288 |
-
"""Decode from the discrete codes to continuous latent space."""
|
| 289 |
-
return self.model.quantizer.from_codes(codes)[0]
|
| 290 |
-
|
| 291 |
-
@property
|
| 292 |
-
def channels(self) -> int:
|
| 293 |
-
return 1
|
| 294 |
-
|
| 295 |
-
@property
|
| 296 |
-
def frame_rate(self) -> float:
|
| 297 |
-
return self.model.sample_rate / self.model.hop_length
|
| 298 |
-
|
| 299 |
-
@property
|
| 300 |
-
def sample_rate(self) -> int:
|
| 301 |
-
return self.model.sample_rate
|
| 302 |
-
|
| 303 |
-
@property
|
| 304 |
-
def cardinality(self) -> int:
|
| 305 |
-
return self.model.codebook_size
|
| 306 |
-
|
| 307 |
-
@property
|
| 308 |
-
def num_codebooks(self) -> int:
|
| 309 |
-
return self.n_quantizers
|
| 310 |
-
|
| 311 |
-
@property
|
| 312 |
-
def total_codebooks(self) -> int:
|
| 313 |
-
return self.model.n_codebooks
|
| 314 |
-
|
| 315 |
-
def set_num_codebooks(self, n: int):
|
| 316 |
-
"""Set the active number of codebooks used by the quantizer.
|
| 317 |
-
"""
|
| 318 |
-
assert n >= 1
|
| 319 |
-
assert n <= self.total_codebooks
|
| 320 |
-
self.n_quantizers = n
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
class HFEncodecCompressionModel(CompressionModel):
|
| 324 |
-
"""Wrapper around HuggingFace Encodec.
|
| 325 |
-
"""
|
| 326 |
-
def __init__(self, model: HFEncodecModel):
|
| 327 |
-
super().__init__()
|
| 328 |
-
self.model = model
|
| 329 |
-
bws = self.model.config.target_bandwidths
|
| 330 |
-
num_codebooks = [
|
| 331 |
-
bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
|
| 332 |
-
for bw in bws
|
| 333 |
-
]
|
| 334 |
-
deltas = [nc - int(nc) for nc in num_codebooks]
|
| 335 |
-
# Checking we didn't do some bad maths and we indeed have integers!
|
| 336 |
-
assert all(deltas) <= 1e-3, deltas
|
| 337 |
-
self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
|
| 338 |
-
self.set_num_codebooks(max(self.possible_num_codebooks))
|
| 339 |
-
|
| 340 |
-
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
| 341 |
-
# We don't support training with this.
|
| 342 |
-
raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
|
| 343 |
-
|
| 344 |
-
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 345 |
-
bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
|
| 346 |
-
bandwidth = self.model.config.target_bandwidths[bandwidth_index]
|
| 347 |
-
res = self.model.encode(x, None, bandwidth)
|
| 348 |
-
assert len(res[0]) == 1
|
| 349 |
-
assert len(res[1]) == 1
|
| 350 |
-
return res[0][0], res[1][0]
|
| 351 |
-
|
| 352 |
-
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
| 353 |
-
if scale is None:
|
| 354 |
-
scales = [None] # type: ignore
|
| 355 |
-
else:
|
| 356 |
-
scales = scale # type: ignore
|
| 357 |
-
res = self.model.decode(codes[None], scales)
|
| 358 |
-
return res[0]
|
| 359 |
-
|
| 360 |
-
def decode_latent(self, codes: torch.Tensor):
|
| 361 |
-
"""Decode from the discrete codes to continuous latent space."""
|
| 362 |
-
return self.model.quantizer.decode(codes.transpose(0, 1))
|
| 363 |
-
|
| 364 |
-
@property
|
| 365 |
-
def channels(self) -> int:
|
| 366 |
-
return self.model.config.audio_channels
|
| 367 |
-
|
| 368 |
-
@property
|
| 369 |
-
def frame_rate(self) -> float:
|
| 370 |
-
hop_length = int(np.prod(self.model.config.upsampling_ratios))
|
| 371 |
-
return self.sample_rate / hop_length
|
| 372 |
-
|
| 373 |
-
@property
|
| 374 |
-
def sample_rate(self) -> int:
|
| 375 |
-
return self.model.config.sampling_rate
|
| 376 |
-
|
| 377 |
-
@property
|
| 378 |
-
def cardinality(self) -> int:
|
| 379 |
-
return self.model.config.codebook_size
|
| 380 |
-
|
| 381 |
-
@property
|
| 382 |
-
def num_codebooks(self) -> int:
|
| 383 |
-
return self._num_codebooks
|
| 384 |
-
|
| 385 |
-
@property
|
| 386 |
-
def total_codebooks(self) -> int:
|
| 387 |
-
return max(self.possible_num_codebooks)
|
| 388 |
-
|
| 389 |
-
def set_num_codebooks(self, n: int):
|
| 390 |
-
"""Set the active number of codebooks used by the quantizer.
|
| 391 |
-
"""
|
| 392 |
-
if n not in self.possible_num_codebooks:
|
| 393 |
-
raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
|
| 394 |
-
self._num_codebooks = n
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
class InterleaveStereoCompressionModel(CompressionModel):
|
| 398 |
-
"""Wraps a CompressionModel to support stereo inputs. The wrapped model
|
| 399 |
-
will be applied independently to the left and right channels, and both codebooks
|
| 400 |
-
will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
|
| 401 |
-
channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
|
| 402 |
-
`per_timestep`.
|
| 403 |
-
|
| 404 |
-
Args:
|
| 405 |
-
model (CompressionModel): Compression model to wrap.
|
| 406 |
-
per_timestep (bool): Whether to interleave on the timestep dimension
|
| 407 |
-
or on the codebooks dimension.
|
| 408 |
-
"""
|
| 409 |
-
def __init__(self, model: CompressionModel, per_timestep: bool = False):
|
| 410 |
-
super().__init__()
|
| 411 |
-
self.model = model
|
| 412 |
-
self.per_timestep = per_timestep
|
| 413 |
-
assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
|
| 414 |
-
|
| 415 |
-
@property
|
| 416 |
-
def total_codebooks(self):
|
| 417 |
-
return self.model.total_codebooks
|
| 418 |
-
|
| 419 |
-
@property
|
| 420 |
-
def num_codebooks(self):
|
| 421 |
-
"""Active number of codebooks used by the quantizer.
|
| 422 |
-
|
| 423 |
-
..Warning:: this reports the number of codebooks after the interleaving
|
| 424 |
-
of the codebooks!
|
| 425 |
-
"""
|
| 426 |
-
return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
|
| 427 |
-
|
| 428 |
-
def set_num_codebooks(self, n: int):
|
| 429 |
-
"""Set the active number of codebooks used by the quantizer.
|
| 430 |
-
|
| 431 |
-
..Warning:: this sets the number of codebooks before the interleaving!
|
| 432 |
-
"""
|
| 433 |
-
self.model.set_num_codebooks(n)
|
| 434 |
-
|
| 435 |
-
@property
|
| 436 |
-
def num_virtual_steps(self) -> float:
|
| 437 |
-
"""Return the number of virtual steps, e.g. one real step
|
| 438 |
-
will be split into that many steps.
|
| 439 |
-
"""
|
| 440 |
-
return 2 if self.per_timestep else 1
|
| 441 |
-
|
| 442 |
-
@property
|
| 443 |
-
def frame_rate(self) -> float:
|
| 444 |
-
return self.model.frame_rate * self.num_virtual_steps
|
| 445 |
-
|
| 446 |
-
@property
|
| 447 |
-
def sample_rate(self) -> int:
|
| 448 |
-
return self.model.sample_rate
|
| 449 |
-
|
| 450 |
-
@property
|
| 451 |
-
def channels(self) -> int:
|
| 452 |
-
return 2
|
| 453 |
-
|
| 454 |
-
@property
|
| 455 |
-
def cardinality(self):
|
| 456 |
-
"""Cardinality of each codebook.
|
| 457 |
-
"""
|
| 458 |
-
return self.model.cardinality
|
| 459 |
-
|
| 460 |
-
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
| 461 |
-
raise NotImplementedError("Not supported, use encode and decode.")
|
| 462 |
-
|
| 463 |
-
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 464 |
-
B, C, T = x.shape
|
| 465 |
-
assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
|
| 466 |
-
|
| 467 |
-
indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
|
| 468 |
-
indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
|
| 469 |
-
indices = torch.stack([indices_c0, indices_c1], dim=0)
|
| 470 |
-
scales: tp.Optional[torch.Tensor] = None
|
| 471 |
-
if scales_c0 is not None and scales_c1 is not None:
|
| 472 |
-
scales = torch.stack([scales_c0, scales_c1], dim=1)
|
| 473 |
-
|
| 474 |
-
if self.per_timestep:
|
| 475 |
-
indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
|
| 476 |
-
else:
|
| 477 |
-
indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
|
| 478 |
-
|
| 479 |
-
return (indices, scales)
|
| 480 |
-
|
| 481 |
-
def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 482 |
-
if self.per_timestep:
|
| 483 |
-
codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
|
| 484 |
-
else:
|
| 485 |
-
codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
|
| 486 |
-
return codes[0], codes[1]
|
| 487 |
-
|
| 488 |
-
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
| 489 |
-
B, K, T = codes.shape
|
| 490 |
-
assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
|
| 491 |
-
assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
|
| 492 |
-
|
| 493 |
-
scale_c0, scale_c1 = None, None
|
| 494 |
-
if scale is not None:
|
| 495 |
-
assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
|
| 496 |
-
scale_c0 = scale[0, ...]
|
| 497 |
-
scale_c1 = scale[1, ...]
|
| 498 |
-
|
| 499 |
-
codes_c0, codes_c1 = self.get_left_right_codes(codes)
|
| 500 |
-
audio_c0 = self.model.decode(codes_c0, scale_c0)
|
| 501 |
-
audio_c1 = self.model.decode(codes_c1, scale_c1)
|
| 502 |
-
return torch.cat([audio_c0, audio_c1], dim=1)
|
| 503 |
-
|
| 504 |
-
def decode_latent(self, codes: torch.Tensor):
|
| 505 |
-
"""Decode from the discrete codes to continuous latent space."""
|
| 506 |
-
raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/flow_matching.py
DELETED
|
@@ -1,516 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from dataclasses import dataclass
|
| 8 |
-
from functools import partial
|
| 9 |
-
import logging
|
| 10 |
-
import math
|
| 11 |
-
import typing as tp
|
| 12 |
-
import torch
|
| 13 |
-
from torch import nn
|
| 14 |
-
from torchdiffeq import odeint # type: ignore
|
| 15 |
-
from ..modules.streaming import StreamingModule
|
| 16 |
-
from ..modules.transformer import create_norm_fn, StreamingTransformerLayer
|
| 17 |
-
from ..modules.unet_transformer import UnetTransformer
|
| 18 |
-
from ..modules.conditioners import (
|
| 19 |
-
ConditionFuser,
|
| 20 |
-
ClassifierFreeGuidanceDropout,
|
| 21 |
-
AttributeDropout,
|
| 22 |
-
ConditioningAttributes,
|
| 23 |
-
JascoCondConst
|
| 24 |
-
)
|
| 25 |
-
from ..modules.jasco_conditioners import JascoConditioningProvider
|
| 26 |
-
from ..modules.activations import get_activation_fn
|
| 27 |
-
|
| 28 |
-
from .lm import ConditionTensors, init_layer
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
logger = logging.getLogger(__name__)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
@dataclass
|
| 35 |
-
class FMOutput:
|
| 36 |
-
latents: torch.Tensor # [B, T, D]
|
| 37 |
-
mask: torch.Tensor # [B, T]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class CFGTerm:
|
| 41 |
-
"""
|
| 42 |
-
Base class for Multi Source Classifier-Free Guidance (CFG) terms. This class represents a term in the CFG process,
|
| 43 |
-
which is used to guide the generation process by adjusting the influence of different conditions.
|
| 44 |
-
Attributes:
|
| 45 |
-
conditions (dict): A dictionary of conditions that influence the generation process.
|
| 46 |
-
weight (float): The weight of the CFG term, determining its influence on the generation.
|
| 47 |
-
"""
|
| 48 |
-
def __init__(self, conditions, weight):
|
| 49 |
-
self.conditions = conditions
|
| 50 |
-
self.weight = weight
|
| 51 |
-
|
| 52 |
-
def drop_irrelevant_conds(self, conditions):
|
| 53 |
-
"""
|
| 54 |
-
Drops irrelevant conditions from the CFG term. This method should be implemented by subclasses.
|
| 55 |
-
Args:
|
| 56 |
-
conditions (dict): The conditions to be filtered.
|
| 57 |
-
Raises:
|
| 58 |
-
NotImplementedError: If the method is not implemented in a subclass.
|
| 59 |
-
"""
|
| 60 |
-
raise NotImplementedError("No base implementation for setting generation params.")
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class AllCFGTerm(CFGTerm):
|
| 64 |
-
"""
|
| 65 |
-
A CFG term that retains all conditions. This class does not drop any condition.
|
| 66 |
-
"""
|
| 67 |
-
def __init__(self, conditions, weight):
|
| 68 |
-
super().__init__(conditions, weight)
|
| 69 |
-
self.drop_irrelevant_conds()
|
| 70 |
-
|
| 71 |
-
def drop_irrelevant_conds(self):
|
| 72 |
-
pass
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class NullCFGTerm(CFGTerm):
|
| 76 |
-
"""
|
| 77 |
-
A CFG term that drops all conditions, effectively nullifying their influence.
|
| 78 |
-
"""
|
| 79 |
-
def __init__(self, conditions, weight):
|
| 80 |
-
super().__init__(conditions, weight)
|
| 81 |
-
self.drop_irrelevant_conds()
|
| 82 |
-
|
| 83 |
-
def drop_irrelevant_conds(self):
|
| 84 |
-
"""
|
| 85 |
-
Drops all conditions by applying a dropout with probability 1.0, effectively nullifying their influence.
|
| 86 |
-
"""
|
| 87 |
-
self.conditions = ClassifierFreeGuidanceDropout(p=1.0)(
|
| 88 |
-
samples=self.conditions,
|
| 89 |
-
cond_types=["wav", "text", "symbolic"])
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
class TextCFGTerm(CFGTerm):
|
| 93 |
-
"""
|
| 94 |
-
A CFG term that selectively drops conditions based on specified dropout probabilities for different types
|
| 95 |
-
of conditions, such as 'symbolic' and 'wav'.
|
| 96 |
-
"""
|
| 97 |
-
def __init__(self, conditions, weight, model_att_dropout):
|
| 98 |
-
"""
|
| 99 |
-
Initializes a TextCFGTerm with specified conditions, weight, and model attention dropout configuration.
|
| 100 |
-
Args:
|
| 101 |
-
conditions (dict): The conditions to be used in the CFG process.
|
| 102 |
-
weight (float): The weight of the CFG term.
|
| 103 |
-
model_att_dropout (object): The attribute dropouts used by the model.
|
| 104 |
-
"""
|
| 105 |
-
super().__init__(conditions, weight)
|
| 106 |
-
if 'symbolic' in model_att_dropout.p:
|
| 107 |
-
self.drop_symbolics = {k: 1.0 for k in model_att_dropout.p['symbolic'].keys()}
|
| 108 |
-
else:
|
| 109 |
-
self.drop_symbolics = {}
|
| 110 |
-
if 'wav' in model_att_dropout.p:
|
| 111 |
-
self.drop_wav = {k: 1.0 for k in model_att_dropout.p['wav'].keys()}
|
| 112 |
-
else:
|
| 113 |
-
self.drop_wav = {}
|
| 114 |
-
self.drop_irrelevant_conds()
|
| 115 |
-
|
| 116 |
-
def drop_irrelevant_conds(self):
|
| 117 |
-
self.conditions = AttributeDropout({'symbolic': self.drop_symbolics,
|
| 118 |
-
'wav': self.drop_wav})(self.conditions) # drop temporal conds
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
class FlowMatchingModel(StreamingModule):
|
| 122 |
-
"""
|
| 123 |
-
A flow matching model inherits from StreamingModule.
|
| 124 |
-
This model uses a transformer architecture to process and fuse conditions, applying learned embeddings and
|
| 125 |
-
transformations and predicts multi-source guided vector fields.
|
| 126 |
-
Attributes:
|
| 127 |
-
condition_provider (JascoConditioningProvider): Provider for conditioning attributes.
|
| 128 |
-
fuser (ConditionFuser): Fuser for combining multiple conditions.
|
| 129 |
-
dim (int): Dimensionality of the model's main features.
|
| 130 |
-
num_heads (int): Number of attention heads in the transformer.
|
| 131 |
-
flow_dim (int): Dimensionality of the flow features.
|
| 132 |
-
chords_dim (int): Dimensionality for chord embeddings, if used.
|
| 133 |
-
drums_dim (int): Dimensionality for drums embeddings, if used.
|
| 134 |
-
melody_dim (int): Dimensionality for melody embeddings, if used.
|
| 135 |
-
hidden_scale (int): Scaling factor for the dimensionality of the feedforward network in the transformer.
|
| 136 |
-
norm (str): Type of normalization to use ('layer_norm' or other supported types).
|
| 137 |
-
norm_first (bool): Whether to apply normalization before other operations in the transformer layers.
|
| 138 |
-
bias_proj (bool): Whether to include bias in the projection layers.
|
| 139 |
-
weight_init (Optional[str]): Method for initializing weights.
|
| 140 |
-
depthwise_init (Optional[str]): Method for initializing depthwise convolutional layers.
|
| 141 |
-
zero_bias_init (bool): Whether to initialize biases to zero.
|
| 142 |
-
cfg_dropout (float): Dropout rate for configuration settings.
|
| 143 |
-
cfg_coef (float): Coefficient for configuration influence.
|
| 144 |
-
attribute_dropout (Dict[str, Dict[str, float]]): Dropout rates for specific attributes.
|
| 145 |
-
time_embedding_dim (int): Dimensionality of time embeddings.
|
| 146 |
-
**kwargs: Additional keyword arguments for the transformer.
|
| 147 |
-
Methods:
|
| 148 |
-
__init__: Initializes the model with the specified attributes and configuration.
|
| 149 |
-
"""
|
| 150 |
-
def __init__(self, condition_provider: JascoConditioningProvider,
|
| 151 |
-
fuser: ConditionFuser,
|
| 152 |
-
dim: int = 128,
|
| 153 |
-
num_heads: int = 8,
|
| 154 |
-
flow_dim: int = 128,
|
| 155 |
-
chords_dim: int = 0,
|
| 156 |
-
drums_dim: int = 0,
|
| 157 |
-
melody_dim: int = 0,
|
| 158 |
-
hidden_scale: int = 4,
|
| 159 |
-
norm: str = 'layer_norm',
|
| 160 |
-
norm_first: bool = False,
|
| 161 |
-
bias_proj: bool = True,
|
| 162 |
-
weight_init: tp.Optional[str] = None,
|
| 163 |
-
depthwise_init: tp.Optional[str] = None,
|
| 164 |
-
zero_bias_init: bool = False,
|
| 165 |
-
cfg_dropout: float = 0,
|
| 166 |
-
cfg_coef: float = 1.0,
|
| 167 |
-
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
|
| 168 |
-
time_embedding_dim: int = 128,
|
| 169 |
-
**kwargs):
|
| 170 |
-
super().__init__()
|
| 171 |
-
self.cfg_coef = cfg_coef
|
| 172 |
-
|
| 173 |
-
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
| 174 |
-
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
| 175 |
-
self.condition_provider = condition_provider
|
| 176 |
-
self.fuser = fuser
|
| 177 |
-
self.dim = dim # transformer dim
|
| 178 |
-
self.flow_dim = flow_dim
|
| 179 |
-
self.chords_dim = chords_dim
|
| 180 |
-
self.emb = nn.Linear(flow_dim + chords_dim + drums_dim + melody_dim, dim, bias=False)
|
| 181 |
-
if 'activation' in kwargs:
|
| 182 |
-
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
| 183 |
-
|
| 184 |
-
self.transformer = UnetTransformer(
|
| 185 |
-
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
| 186 |
-
norm=norm, norm_first=norm_first,
|
| 187 |
-
layer_class=StreamingTransformerLayer,
|
| 188 |
-
**kwargs)
|
| 189 |
-
self.out_norm: tp.Optional[nn.Module] = None
|
| 190 |
-
if norm_first:
|
| 191 |
-
self.out_norm = create_norm_fn(norm, dim)
|
| 192 |
-
self.linear = nn.Linear(dim, flow_dim, bias=bias_proj)
|
| 193 |
-
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
| 194 |
-
self._fsdp: tp.Optional[nn.Module]
|
| 195 |
-
self.__dict__['_fsdp'] = None
|
| 196 |
-
|
| 197 |
-
# init time parameter embedding
|
| 198 |
-
self.d_temb1 = time_embedding_dim
|
| 199 |
-
self.d_temb2 = 4 * time_embedding_dim
|
| 200 |
-
self.temb = nn.Module()
|
| 201 |
-
self.temb.dense = nn.ModuleList([
|
| 202 |
-
torch.nn.Linear(self.d_temb1,
|
| 203 |
-
self.d_temb2),
|
| 204 |
-
torch.nn.Linear(self.d_temb2,
|
| 205 |
-
self.d_temb2),
|
| 206 |
-
])
|
| 207 |
-
self.temb_proj = nn.Linear(self.d_temb2, dim)
|
| 208 |
-
|
| 209 |
-
def _get_timestep_embedding(self, timesteps, embedding_dim):
|
| 210 |
-
"""
|
| 211 |
-
#######################################################################################################
|
| 212 |
-
TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
|
| 213 |
-
#######################################################################################################
|
| 214 |
-
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
| 215 |
-
From Fairseq.
|
| 216 |
-
Build sinusoidal embeddings.
|
| 217 |
-
This matches the implementation in tensor2tensor, but differs slightly
|
| 218 |
-
from the description in Section 3.5 of "Attention Is All You Need".
|
| 219 |
-
"""
|
| 220 |
-
assert len(timesteps.shape) == 1
|
| 221 |
-
|
| 222 |
-
half_dim = embedding_dim // 2
|
| 223 |
-
emb = math.log(10000) / (half_dim - 1)
|
| 224 |
-
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
| 225 |
-
emb = emb.to(device=timesteps.device)
|
| 226 |
-
emb = timesteps.float()[:, None] * emb[None, :]
|
| 227 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 228 |
-
if embedding_dim % 2 == 1: # zero pad
|
| 229 |
-
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 230 |
-
return emb
|
| 231 |
-
|
| 232 |
-
def _embed_time_parameter(self, t: torch.Tensor):
|
| 233 |
-
"""
|
| 234 |
-
#######################################################################################################
|
| 235 |
-
TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py
|
| 236 |
-
#######################################################################################################
|
| 237 |
-
"""
|
| 238 |
-
temb = self._get_timestep_embedding(t.flatten(), self.d_temb1)
|
| 239 |
-
temb = self.temb.dense[0](temb)
|
| 240 |
-
temb = temb * torch.sigmoid(temb) # swish activation
|
| 241 |
-
temb = self.temb.dense[1](temb)
|
| 242 |
-
return temb
|
| 243 |
-
|
| 244 |
-
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
| 245 |
-
"""Initialization of the transformer module weights.
|
| 246 |
-
|
| 247 |
-
Args:
|
| 248 |
-
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
| 249 |
-
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
| 250 |
-
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
| 251 |
-
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
| 252 |
-
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
| 253 |
-
"""
|
| 254 |
-
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
| 255 |
-
assert depthwise_init is None or weight_init is not None, \
|
| 256 |
-
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
| 257 |
-
assert not zero_bias_init or weight_init is not None, \
|
| 258 |
-
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
| 259 |
-
|
| 260 |
-
if weight_init is None:
|
| 261 |
-
return
|
| 262 |
-
|
| 263 |
-
init_layer(self.emb, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
| 264 |
-
|
| 265 |
-
for layer_idx, tr_layer in enumerate(self.transformer.layers):
|
| 266 |
-
depth = None
|
| 267 |
-
if depthwise_init == 'current':
|
| 268 |
-
depth = layer_idx + 1
|
| 269 |
-
elif depthwise_init == 'global':
|
| 270 |
-
depth = len(self.transformer.layers)
|
| 271 |
-
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
|
| 272 |
-
tr_layer.apply(init_fn)
|
| 273 |
-
|
| 274 |
-
init_layer(self.linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
| 275 |
-
|
| 276 |
-
def _align_seq_length(self,
|
| 277 |
-
cond: torch.Tensor,
|
| 278 |
-
seq_len: int = 500):
|
| 279 |
-
# trim if needed
|
| 280 |
-
cond = cond[:, :seq_len, :]
|
| 281 |
-
|
| 282 |
-
# pad if needed
|
| 283 |
-
B, T, C = cond.shape
|
| 284 |
-
if T < seq_len:
|
| 285 |
-
cond = torch.cat((cond, torch.zeros((B, seq_len - T, C), dtype=cond.dtype, device=cond.device)), dim=1)
|
| 286 |
-
|
| 287 |
-
return cond
|
| 288 |
-
|
| 289 |
-
def forward(self,
|
| 290 |
-
latents: torch.Tensor,
|
| 291 |
-
t: torch.Tensor,
|
| 292 |
-
conditions: tp.List[ConditioningAttributes],
|
| 293 |
-
condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
|
| 294 |
-
"""Apply flow matching forward pass on latents and conditions.
|
| 295 |
-
Given a tensor of noisy latents of shape [B, T, D] with D the flow dim and T the sequence steps,
|
| 296 |
-
and a time parameter tensor t, return the vector field with shape [B, T, D].
|
| 297 |
-
|
| 298 |
-
Args:
|
| 299 |
-
latents (torch.Tensor): noisy latents.
|
| 300 |
-
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
| 301 |
-
the given codes. Note that when evaluating multiple time with the same conditioning
|
| 302 |
-
you should pre-compute those and pass them as `condition_tensors`.
|
| 303 |
-
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
| 304 |
-
tensors, see `conditions`.
|
| 305 |
-
Returns:
|
| 306 |
-
torch.Tensor: estimated vector field v_theta.
|
| 307 |
-
"""
|
| 308 |
-
assert condition_tensors is not None, "FlowMatchingModel require pre-calculation of condition tensors"
|
| 309 |
-
assert not conditions, "Shouldn't pass unprocessed conditions to FlowMatchingModel."
|
| 310 |
-
|
| 311 |
-
B, T, D = latents.shape
|
| 312 |
-
x = latents
|
| 313 |
-
|
| 314 |
-
# concat temporal conditions on the feature dimension
|
| 315 |
-
temporal_conds = JascoCondConst.ALL.value
|
| 316 |
-
for cond in temporal_conds:
|
| 317 |
-
if cond not in condition_tensors:
|
| 318 |
-
continue
|
| 319 |
-
c = self._align_seq_length(condition_tensors[cond][0], seq_len=T)
|
| 320 |
-
x = torch.concat((x, c), dim=-1)
|
| 321 |
-
|
| 322 |
-
# project to transformer dimension
|
| 323 |
-
input_ = self.emb(x)
|
| 324 |
-
|
| 325 |
-
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
| 326 |
-
|
| 327 |
-
# embed time parameter
|
| 328 |
-
t_embs = self._embed_time_parameter(t)
|
| 329 |
-
|
| 330 |
-
# add it to cross_attention_input
|
| 331 |
-
cross_attention_input = cross_attention_input + self.temb_proj(t_embs[:, None, :])
|
| 332 |
-
|
| 333 |
-
out = self.transformer(input_, cross_attention_src=cross_attention_input)
|
| 334 |
-
|
| 335 |
-
if self.out_norm:
|
| 336 |
-
out = self.out_norm(out)
|
| 337 |
-
v_theta = self.linear(out) # [B, T, D]
|
| 338 |
-
|
| 339 |
-
# remove the prefix from the model outputs
|
| 340 |
-
if len(self.fuser.fuse2cond['prepend']) > 0:
|
| 341 |
-
v_theta = v_theta[:, :, -T:]
|
| 342 |
-
|
| 343 |
-
return v_theta # [B, T, D]
|
| 344 |
-
|
| 345 |
-
def _multi_source_cfg_preprocess(self,
|
| 346 |
-
conditions: tp.List[ConditioningAttributes],
|
| 347 |
-
cfg_coef_all: float,
|
| 348 |
-
cfg_coef_txt: float,
|
| 349 |
-
min_weight: float = 1e-6):
|
| 350 |
-
"""
|
| 351 |
-
Preprocesses the CFG terms for multi-source conditional generation.
|
| 352 |
-
Args:
|
| 353 |
-
conditions (list): A list of conditions to be applied.
|
| 354 |
-
cfg_coef_all (float): The coefficient for all conditions.
|
| 355 |
-
cfg_coef_txt (float): The coefficient for text conditions.
|
| 356 |
-
min_weight (float): The minimal absolute weight for calculating a CFG term.
|
| 357 |
-
Returns:
|
| 358 |
-
tuple: A tuple containing condition_tensors and cfg_terms.
|
| 359 |
-
condition_tensors is a dictionary or ConditionTensors object with tokenized conditions.
|
| 360 |
-
cfg_terms is a list of CFGTerm objects with weights adjusted based on the coefficients.
|
| 361 |
-
"""
|
| 362 |
-
condition_tensors: tp.Optional[ConditionTensors]
|
| 363 |
-
cfg_terms = []
|
| 364 |
-
if conditions:
|
| 365 |
-
# conditional terms
|
| 366 |
-
cfg_terms = [AllCFGTerm(conditions=conditions, weight=cfg_coef_all),
|
| 367 |
-
TextCFGTerm(conditions=conditions, weight=cfg_coef_txt,
|
| 368 |
-
model_att_dropout=self.att_dropout)]
|
| 369 |
-
|
| 370 |
-
# add null term
|
| 371 |
-
cfg_terms.append(NullCFGTerm(conditions=conditions, weight=1 - sum([ct.weight for ct in cfg_terms])))
|
| 372 |
-
|
| 373 |
-
# remove terms with negligible weight
|
| 374 |
-
for ct in cfg_terms:
|
| 375 |
-
if abs(ct.weight) < min_weight:
|
| 376 |
-
cfg_terms.remove(ct)
|
| 377 |
-
|
| 378 |
-
conds: tp.List[ConditioningAttributes] = sum([ct.conditions for ct in cfg_terms], [])
|
| 379 |
-
tokenized = self.condition_provider.tokenize(conds)
|
| 380 |
-
condition_tensors = self.condition_provider(tokenized)
|
| 381 |
-
else:
|
| 382 |
-
condition_tensors = {}
|
| 383 |
-
|
| 384 |
-
return condition_tensors, cfg_terms
|
| 385 |
-
|
| 386 |
-
def estimated_vector_field(self, z, t, condition_tensors=None, cfg_terms=[]):
|
| 387 |
-
"""
|
| 388 |
-
Estimates the vector field for the given latent variables and time parameter,
|
| 389 |
-
conditioned on the provided conditions.
|
| 390 |
-
Args:
|
| 391 |
-
z (Tensor): The latent variables.
|
| 392 |
-
t (float): The time variable.
|
| 393 |
-
condition_tensors (ConditionTensors, optional): The condition tensors. Defaults to None.
|
| 394 |
-
cfg_terms (list, optional): The list of CFG terms. Defaults to an empty list.
|
| 395 |
-
Returns:
|
| 396 |
-
Tensor: The estimated vector field.
|
| 397 |
-
"""
|
| 398 |
-
if len(cfg_terms) > 1:
|
| 399 |
-
z = z.repeat(len(cfg_terms), 1, 1) # duplicate noisy latents for multi-source CFG
|
| 400 |
-
v_thetas = self(latents=z, t=t, conditions=[], condition_tensors=condition_tensors)
|
| 401 |
-
return self._multi_source_cfg_postprocess(v_thetas, cfg_terms)
|
| 402 |
-
|
| 403 |
-
def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms):
|
| 404 |
-
"""
|
| 405 |
-
Postprocesses the vector fields generated for each CFG term to combine them into a single vector field.
|
| 406 |
-
Multi source guidance occurs here.
|
| 407 |
-
Args:
|
| 408 |
-
v_thetas (Tensor): The vector fields for each CFG term.
|
| 409 |
-
cfg_terms (list): The CFG terms used.
|
| 410 |
-
Returns:
|
| 411 |
-
Tensor: The combined vector field.
|
| 412 |
-
"""
|
| 413 |
-
if len(cfg_terms) <= 1:
|
| 414 |
-
return v_thetas
|
| 415 |
-
v_theta_per_term = v_thetas.chunk(len(cfg_terms))
|
| 416 |
-
return sum([ct.weight * term_vf for ct, term_vf in zip(cfg_terms, v_theta_per_term)])
|
| 417 |
-
|
| 418 |
-
@torch.no_grad()
|
| 419 |
-
def generate(self,
|
| 420 |
-
prompt: tp.Optional[torch.Tensor] = None,
|
| 421 |
-
conditions: tp.List[ConditioningAttributes] = [],
|
| 422 |
-
num_samples: tp.Optional[int] = None,
|
| 423 |
-
max_gen_len: int = 256,
|
| 424 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
| 425 |
-
cfg_coef_all: float = 3.0,
|
| 426 |
-
cfg_coef_txt: float = 1.0,
|
| 427 |
-
euler: bool = False,
|
| 428 |
-
euler_steps: int = 100,
|
| 429 |
-
ode_rtol: float = 1e-5,
|
| 430 |
-
ode_atol: float = 1e-5,
|
| 431 |
-
) -> torch.Tensor:
|
| 432 |
-
"""
|
| 433 |
-
Generate audio latents given a prompt or unconditionally. This method supports both Euler integration
|
| 434 |
-
and adaptive ODE solving to generate sequences based on the specified conditions and configuration coefficients.
|
| 435 |
-
|
| 436 |
-
Args:
|
| 437 |
-
prompt (torch.Tensor, optional): Initial prompt to condition the generation. defaults to None
|
| 438 |
-
conditions (List[ConditioningAttributes]): List of conditioning attributes - text, symbolic or audio.
|
| 439 |
-
num_samples (int, optional): Number of samples to generate.
|
| 440 |
-
If None, it is inferred from the number of conditions.
|
| 441 |
-
max_gen_len (int): Maximum length of the generated sequence.
|
| 442 |
-
callback (Callable[[int, int], None], optional): Callback function to monitor the generation process.
|
| 443 |
-
cfg_coef_all (float): Coefficient for the fully conditional CFG term.
|
| 444 |
-
cfg_coef_txt (float): Coefficient for text CFG term.
|
| 445 |
-
euler (bool): If True, use Euler integration, otherwise use adaptive ODE solver.
|
| 446 |
-
euler_steps (int): Number of Euler steps to perform if Euler integration is used.
|
| 447 |
-
ode_rtol (float): ODE solver rtol threshold.
|
| 448 |
-
ode_atol (float): ODE solver atol threshold.
|
| 449 |
-
|
| 450 |
-
Returns:
|
| 451 |
-
torch.Tensor: Generated latents, shaped as (num_samples, max_gen_len, feature_dim).
|
| 452 |
-
"""
|
| 453 |
-
|
| 454 |
-
assert not self.training, "generation shouldn't be used in training mode."
|
| 455 |
-
first_param = next(iter(self.parameters()))
|
| 456 |
-
device = first_param.device
|
| 457 |
-
|
| 458 |
-
# Checking all input shapes are consistent.
|
| 459 |
-
possible_num_samples = []
|
| 460 |
-
if num_samples is not None:
|
| 461 |
-
possible_num_samples.append(num_samples)
|
| 462 |
-
elif prompt is not None:
|
| 463 |
-
possible_num_samples.append(prompt.shape[0])
|
| 464 |
-
elif conditions:
|
| 465 |
-
possible_num_samples.append(len(conditions))
|
| 466 |
-
else:
|
| 467 |
-
possible_num_samples.append(1)
|
| 468 |
-
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
| 469 |
-
num_samples = possible_num_samples[0]
|
| 470 |
-
|
| 471 |
-
condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(conditions, cfg_coef_all, cfg_coef_txt)
|
| 472 |
-
|
| 473 |
-
# flow matching inference
|
| 474 |
-
B, T, D = num_samples, max_gen_len, self.flow_dim
|
| 475 |
-
|
| 476 |
-
z_0 = torch.randn((B, T, D), device=device)
|
| 477 |
-
|
| 478 |
-
if euler:
|
| 479 |
-
# vanilla Euler intergration
|
| 480 |
-
dt = (1 / euler_steps)
|
| 481 |
-
z = z_0
|
| 482 |
-
t = torch.zeros((1, ), device=device)
|
| 483 |
-
for _ in range(euler_steps):
|
| 484 |
-
v_theta = self.estimated_vector_field(z, t,
|
| 485 |
-
condition_tensors=condition_tensors,
|
| 486 |
-
cfg_terms=cfg_terms)
|
| 487 |
-
z = z + dt * v_theta
|
| 488 |
-
t = t + dt
|
| 489 |
-
z_1 = z
|
| 490 |
-
else:
|
| 491 |
-
# solve with dynamic ode integrator (dopri5)
|
| 492 |
-
t = torch.tensor([0, 1.0 - 1e-5], device=device)
|
| 493 |
-
num_evals = 0
|
| 494 |
-
|
| 495 |
-
# define ode vector field function
|
| 496 |
-
def inner_ode_func(t, z):
|
| 497 |
-
nonlocal num_evals
|
| 498 |
-
num_evals += 1
|
| 499 |
-
if callback is not None:
|
| 500 |
-
ESTIMATED_ODE_SOLVER_STEPS = 300
|
| 501 |
-
callback(num_evals, ESTIMATED_ODE_SOLVER_STEPS)
|
| 502 |
-
return self.estimated_vector_field(z, t,
|
| 503 |
-
condition_tensors=condition_tensors,
|
| 504 |
-
cfg_terms=cfg_terms)
|
| 505 |
-
|
| 506 |
-
ode_opts: dict = {"options": {}}
|
| 507 |
-
z = odeint(
|
| 508 |
-
inner_ode_func,
|
| 509 |
-
z_0,
|
| 510 |
-
t,
|
| 511 |
-
**{"atol": ode_atol, "rtol": ode_rtol, **ode_opts},
|
| 512 |
-
)
|
| 513 |
-
logger.info("Generated in %d steps", num_evals)
|
| 514 |
-
z_1 = z[-1]
|
| 515 |
-
|
| 516 |
-
return z_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/genmodel.py
DELETED
|
@@ -1,273 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Base implementation for audio generative models. This base implementation
|
| 9 |
-
combines all the required components to run inference with pretrained audio
|
| 10 |
-
generative models. It can be easily inherited by downstream model classes to
|
| 11 |
-
provide easy access to the generation API.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from abc import ABC, abstractmethod
|
| 15 |
-
import typing as tp
|
| 16 |
-
|
| 17 |
-
import omegaconf
|
| 18 |
-
import torch
|
| 19 |
-
import gradio as gr
|
| 20 |
-
|
| 21 |
-
from .encodec import CompressionModel
|
| 22 |
-
from .lm import LMModel
|
| 23 |
-
from .builders import get_wrapped_compression_model
|
| 24 |
-
from ..data.audio_utils import convert_audio
|
| 25 |
-
from ..modules.conditioners import ConditioningAttributes
|
| 26 |
-
from ..utils.autocast import TorchAutocast
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class BaseGenModel(ABC):
|
| 30 |
-
"""Base generative model with convenient generation API.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
name (str): name of the model.
|
| 34 |
-
compression_model (CompressionModel): Compression model
|
| 35 |
-
used to map audio to invertible discrete representations.
|
| 36 |
-
lm (LMModel): Language model over discrete representations.
|
| 37 |
-
max_duration (float, optional): maximum duration the model can produce,
|
| 38 |
-
otherwise, inferred from the training params.
|
| 39 |
-
"""
|
| 40 |
-
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
| 41 |
-
max_duration: tp.Optional[float] = None):
|
| 42 |
-
self.name = name
|
| 43 |
-
self.compression_model = compression_model
|
| 44 |
-
self.lm = lm
|
| 45 |
-
self.cfg: tp.Optional[omegaconf.DictConfig] = None
|
| 46 |
-
# Just to be safe, let's put everything in eval mode.
|
| 47 |
-
self.compression_model.eval()
|
| 48 |
-
self.lm.eval()
|
| 49 |
-
|
| 50 |
-
if hasattr(lm, 'cfg'):
|
| 51 |
-
cfg = lm.cfg
|
| 52 |
-
assert isinstance(cfg, omegaconf.DictConfig)
|
| 53 |
-
self.cfg = cfg
|
| 54 |
-
|
| 55 |
-
if self.cfg is not None:
|
| 56 |
-
self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
|
| 57 |
-
|
| 58 |
-
if max_duration is None:
|
| 59 |
-
if self.cfg is not None:
|
| 60 |
-
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
| 61 |
-
else:
|
| 62 |
-
raise ValueError("You must provide max_duration when building directly your GenModel")
|
| 63 |
-
assert max_duration is not None
|
| 64 |
-
|
| 65 |
-
self.max_duration: float = max_duration
|
| 66 |
-
self.duration = self.max_duration
|
| 67 |
-
|
| 68 |
-
# self.extend_stride is the length of audio extension when generating samples longer
|
| 69 |
-
# than self.max_duration. NOTE: the derived class must set self.extend_stride to a
|
| 70 |
-
# positive float value when generating with self.duration > self.max_duration.
|
| 71 |
-
self.extend_stride: tp.Optional[float] = None
|
| 72 |
-
self.device = next(iter(lm.parameters())).device
|
| 73 |
-
self.generation_params: dict = {}
|
| 74 |
-
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
| 75 |
-
if self.device.type == 'cpu':
|
| 76 |
-
self.autocast = TorchAutocast(enabled=False)
|
| 77 |
-
else:
|
| 78 |
-
self.autocast = TorchAutocast(
|
| 79 |
-
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
| 80 |
-
|
| 81 |
-
@property
|
| 82 |
-
def frame_rate(self) -> float:
|
| 83 |
-
"""Roughly the number of AR steps per seconds."""
|
| 84 |
-
return self.compression_model.frame_rate
|
| 85 |
-
|
| 86 |
-
@property
|
| 87 |
-
def sample_rate(self) -> int:
|
| 88 |
-
"""Sample rate of the generated audio."""
|
| 89 |
-
return self.compression_model.sample_rate
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def audio_channels(self) -> int:
|
| 93 |
-
"""Audio channels of the generated audio."""
|
| 94 |
-
return self.compression_model.channels
|
| 95 |
-
|
| 96 |
-
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
| 97 |
-
"""Override the default progress callback."""
|
| 98 |
-
self._progress_callback = progress_callback
|
| 99 |
-
|
| 100 |
-
@abstractmethod
|
| 101 |
-
def set_generation_params(self, *args, **kwargs):
|
| 102 |
-
"""Set the generation parameters."""
|
| 103 |
-
raise NotImplementedError("No base implementation for setting generation params.")
|
| 104 |
-
|
| 105 |
-
@staticmethod
|
| 106 |
-
@abstractmethod
|
| 107 |
-
def get_pretrained(name: str, device=None):
|
| 108 |
-
raise NotImplementedError("No base implementation for getting pretrained model")
|
| 109 |
-
|
| 110 |
-
@torch.no_grad()
|
| 111 |
-
def _prepare_tokens_and_attributes(
|
| 112 |
-
self,
|
| 113 |
-
descriptions: tp.Sequence[tp.Optional[str]],
|
| 114 |
-
prompt: tp.Optional[torch.Tensor],
|
| 115 |
-
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
|
| 116 |
-
"""Prepare model inputs.
|
| 117 |
-
|
| 118 |
-
Args:
|
| 119 |
-
descriptions (list of str): A list of strings used as text conditioning.
|
| 120 |
-
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 121 |
-
"""
|
| 122 |
-
attributes = [
|
| 123 |
-
ConditioningAttributes(text={'description': description})
|
| 124 |
-
for description in descriptions]
|
| 125 |
-
|
| 126 |
-
if prompt is not None:
|
| 127 |
-
if descriptions is not None:
|
| 128 |
-
assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
|
| 129 |
-
prompt = prompt.to(self.device)
|
| 130 |
-
prompt_tokens, scale = self.compression_model.encode(prompt)
|
| 131 |
-
assert scale is None
|
| 132 |
-
else:
|
| 133 |
-
prompt_tokens = None
|
| 134 |
-
return attributes, prompt_tokens
|
| 135 |
-
|
| 136 |
-
def generate_unconditional(self, num_samples: int, progress: bool = False,
|
| 137 |
-
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
| 138 |
-
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 139 |
-
"""Generate samples in an unconditional manner.
|
| 140 |
-
|
| 141 |
-
Args:
|
| 142 |
-
num_samples (int): Number of samples to be generated.
|
| 143 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 144 |
-
"""
|
| 145 |
-
descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
|
| 146 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
| 147 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 148 |
-
if return_tokens:
|
| 149 |
-
return self.generate_audio(tokens), tokens
|
| 150 |
-
return self.generate_audio(tokens)
|
| 151 |
-
|
| 152 |
-
def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
|
| 153 |
-
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 154 |
-
"""Generate samples conditioned on text.
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
descriptions (list of str): A list of strings used as text conditioning.
|
| 158 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 159 |
-
"""
|
| 160 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
| 161 |
-
assert prompt_tokens is None
|
| 162 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 163 |
-
if return_tokens:
|
| 164 |
-
return self.generate_audio(tokens), tokens
|
| 165 |
-
return self.generate_audio(tokens)
|
| 166 |
-
|
| 167 |
-
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
| 168 |
-
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
| 169 |
-
progress: bool = False, return_tokens: bool = False) \
|
| 170 |
-
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 171 |
-
"""Generate samples conditioned on audio prompts and an optional text description.
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 175 |
-
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
| 176 |
-
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
| 177 |
-
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
| 178 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 179 |
-
"""
|
| 180 |
-
if prompt.dim() == 2:
|
| 181 |
-
prompt = prompt[None]
|
| 182 |
-
if prompt.dim() != 3:
|
| 183 |
-
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
| 184 |
-
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
| 185 |
-
if descriptions is None:
|
| 186 |
-
descriptions = [None] * len(prompt)
|
| 187 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
| 188 |
-
assert prompt_tokens is not None
|
| 189 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 190 |
-
if return_tokens:
|
| 191 |
-
return self.generate_audio(tokens), tokens
|
| 192 |
-
return self.generate_audio(tokens)
|
| 193 |
-
|
| 194 |
-
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
| 195 |
-
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
|
| 196 |
-
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
|
| 200 |
-
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
| 201 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 202 |
-
Returns:
|
| 203 |
-
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 204 |
-
"""
|
| 205 |
-
total_gen_len = int(self.duration * self.frame_rate)
|
| 206 |
-
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
| 207 |
-
current_gen_offset: int = 0
|
| 208 |
-
|
| 209 |
-
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 210 |
-
generated_tokens += current_gen_offset
|
| 211 |
-
generated_tokens /= ((tokens_to_generate) / self.duration)
|
| 212 |
-
tokens_to_generate /= ((tokens_to_generate) / self.duration)
|
| 213 |
-
if self._progress_callback is not None:
|
| 214 |
-
# Note that total_gen_len might be quite wrong depending on the
|
| 215 |
-
# codebook pattern used, but with delay it is almost accurate.
|
| 216 |
-
self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
|
| 217 |
-
if progress_callback is not None:
|
| 218 |
-
# Update Gradio progress bar
|
| 219 |
-
progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
|
| 220 |
-
if progress:
|
| 221 |
-
print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
|
| 222 |
-
|
| 223 |
-
if prompt_tokens is not None:
|
| 224 |
-
if prompt_tokens.shape[-1] > max_prompt_len:
|
| 225 |
-
prompt_tokens = prompt_tokens[..., :max_prompt_len]
|
| 226 |
-
|
| 227 |
-
# callback = None
|
| 228 |
-
callback = _progress_callback
|
| 229 |
-
|
| 230 |
-
if self.duration <= self.max_duration:
|
| 231 |
-
# generate by sampling from LM, simple case.
|
| 232 |
-
with self.autocast:
|
| 233 |
-
gen_tokens = self.lm.generate(
|
| 234 |
-
prompt_tokens, attributes,
|
| 235 |
-
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
| 236 |
-
|
| 237 |
-
else:
|
| 238 |
-
assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
|
| 239 |
-
assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
| 240 |
-
all_tokens = []
|
| 241 |
-
if prompt_tokens is None:
|
| 242 |
-
prompt_length = 0
|
| 243 |
-
else:
|
| 244 |
-
all_tokens.append(prompt_tokens)
|
| 245 |
-
prompt_length = prompt_tokens.shape[-1]
|
| 246 |
-
|
| 247 |
-
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 248 |
-
|
| 249 |
-
while current_gen_offset + prompt_length < total_gen_len:
|
| 250 |
-
time_offset = current_gen_offset / self.frame_rate
|
| 251 |
-
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 252 |
-
max_gen_len = int(chunk_duration * self.frame_rate)
|
| 253 |
-
with self.autocast:
|
| 254 |
-
gen_tokens = self.lm.generate(
|
| 255 |
-
prompt_tokens, attributes,
|
| 256 |
-
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
| 257 |
-
if prompt_tokens is None:
|
| 258 |
-
all_tokens.append(gen_tokens)
|
| 259 |
-
else:
|
| 260 |
-
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
| 261 |
-
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
| 262 |
-
prompt_length = prompt_tokens.shape[-1]
|
| 263 |
-
current_gen_offset += stride_tokens
|
| 264 |
-
|
| 265 |
-
gen_tokens = torch.cat(all_tokens, dim=-1)
|
| 266 |
-
return gen_tokens
|
| 267 |
-
|
| 268 |
-
def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
|
| 269 |
-
"""Generate Audio from tokens."""
|
| 270 |
-
assert gen_tokens.dim() == 3
|
| 271 |
-
with torch.no_grad():
|
| 272 |
-
gen_audio = self.compression_model.decode(gen_tokens, None)
|
| 273 |
-
return gen_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/lm.py
DELETED
|
@@ -1,588 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from dataclasses import dataclass
|
| 8 |
-
from functools import partial
|
| 9 |
-
import logging
|
| 10 |
-
import math
|
| 11 |
-
import typing as tp
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
from torch import nn
|
| 15 |
-
|
| 16 |
-
from ..utils import utils
|
| 17 |
-
from ..modules.streaming import StreamingModule, State
|
| 18 |
-
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
| 19 |
-
from ..modules.conditioners import (
|
| 20 |
-
ConditionFuser,
|
| 21 |
-
ClassifierFreeGuidanceDropout,
|
| 22 |
-
AttributeDropout,
|
| 23 |
-
ConditioningProvider,
|
| 24 |
-
ConditioningAttributes,
|
| 25 |
-
ConditionType,
|
| 26 |
-
_drop_description_condition
|
| 27 |
-
)
|
| 28 |
-
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
| 29 |
-
from ..modules.activations import get_activation_fn
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
logger = logging.getLogger(__name__)
|
| 33 |
-
ConditionTensors = tp.Dict[str, ConditionType]
|
| 34 |
-
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
|
| 38 |
-
"""LM layer initialization.
|
| 39 |
-
Inspired from xlformers: https://github.com/fairinternal/xlformers
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
method (str): Method name for init function. Valid options are:
|
| 43 |
-
'gaussian', 'uniform'.
|
| 44 |
-
input_dim (int): Input dimension of the initialized module.
|
| 45 |
-
init_depth (int, optional): Optional init depth value used to rescale
|
| 46 |
-
the standard deviation if defined.
|
| 47 |
-
"""
|
| 48 |
-
# Compute std
|
| 49 |
-
std = 1 / math.sqrt(input_dim)
|
| 50 |
-
# Rescale with depth
|
| 51 |
-
if init_depth is not None:
|
| 52 |
-
std = std / math.sqrt(2 * init_depth)
|
| 53 |
-
|
| 54 |
-
if method == 'gaussian':
|
| 55 |
-
return partial(
|
| 56 |
-
torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
|
| 57 |
-
)
|
| 58 |
-
elif method == 'uniform':
|
| 59 |
-
bound = math.sqrt(3) * std # ensure the standard deviation is `std`
|
| 60 |
-
return partial(torch.nn.init.uniform_, a=-bound, b=bound)
|
| 61 |
-
else:
|
| 62 |
-
raise ValueError("Unsupported layer initialization method")
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def init_layer(m: nn.Module,
|
| 66 |
-
method: str,
|
| 67 |
-
init_depth: tp.Optional[int] = None,
|
| 68 |
-
zero_bias_init: bool = False):
|
| 69 |
-
"""Wrapper around ``get_init_fn`` for proper initialization of LM modules.
|
| 70 |
-
|
| 71 |
-
Args:
|
| 72 |
-
m (nn.Module): Module to initialize.
|
| 73 |
-
method (str): Method name for the init function.
|
| 74 |
-
init_depth (int, optional): Optional init depth value used to rescale
|
| 75 |
-
the standard deviation if defined.
|
| 76 |
-
zero_bias_init (bool): Whether to initialize the bias to 0 or not.
|
| 77 |
-
"""
|
| 78 |
-
if isinstance(m, nn.Linear):
|
| 79 |
-
init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
|
| 80 |
-
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
| 81 |
-
weight = m.weight.float()
|
| 82 |
-
init_fn(weight)
|
| 83 |
-
m.weight.data[:] = weight.half()
|
| 84 |
-
else:
|
| 85 |
-
init_fn(m.weight)
|
| 86 |
-
if zero_bias_init and m.bias is not None:
|
| 87 |
-
nn.init.constant_(m.bias, 0)
|
| 88 |
-
elif isinstance(m, nn.Embedding):
|
| 89 |
-
init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
|
| 90 |
-
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
|
| 91 |
-
weight = m.weight.float()
|
| 92 |
-
init_fn(weight)
|
| 93 |
-
m.weight.data[:] = weight.half()
|
| 94 |
-
else:
|
| 95 |
-
init_fn(m.weight)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
class ScaledEmbedding(nn.Embedding):
|
| 99 |
-
"""Boost learning rate for embeddings (with `scale`).
|
| 100 |
-
"""
|
| 101 |
-
def __init__(self, *args, lr=None, **kwargs):
|
| 102 |
-
super().__init__(*args, **kwargs)
|
| 103 |
-
self.lr = lr
|
| 104 |
-
|
| 105 |
-
def make_optim_group(self):
|
| 106 |
-
group = {"params": list(self.parameters())}
|
| 107 |
-
if self.lr is not None:
|
| 108 |
-
group["lr"] = self.lr
|
| 109 |
-
return group
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
@dataclass
|
| 113 |
-
class LMOutput:
|
| 114 |
-
# The logits are already re-aligned with the input codes
|
| 115 |
-
# hence no extra shift is required, e.g. when computing CE
|
| 116 |
-
logits: torch.Tensor # [B, K, T, card]
|
| 117 |
-
mask: torch.Tensor # [B, K, T]
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
class LMModel(StreamingModule):
|
| 121 |
-
"""Transformer-based language model on multiple streams of codes.
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
| 125 |
-
condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
|
| 126 |
-
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
| 127 |
-
n_q (int): Number of parallel streams to model.
|
| 128 |
-
card (int): Cardinality, vocabulary size.
|
| 129 |
-
dim (int): Dimension of the transformer encoder.
|
| 130 |
-
num_heads (int): Number of heads for the transformer encoder.
|
| 131 |
-
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
| 132 |
-
norm (str): Normalization method.
|
| 133 |
-
norm_first (bool): Use pre-norm instead of post-norm.
|
| 134 |
-
emb_lr (float, optional): Embedding-specific learning rate.
|
| 135 |
-
bias_proj (bool): Use bias for output projections.
|
| 136 |
-
weight_init (str, optional): Method for weight initialization.
|
| 137 |
-
depthwise_init (str, optional): Method for depthwise weight initialization.
|
| 138 |
-
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
| 139 |
-
cfg_dropout (float): Classifier-free guidance dropout.
|
| 140 |
-
cfg_coef (float): Classifier-free guidance coefficient.
|
| 141 |
-
attribute_dropout (dict): Attribute dropout probabilities.
|
| 142 |
-
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
| 143 |
-
**kwargs: Additional parameters for the transformer encoder.
|
| 144 |
-
"""
|
| 145 |
-
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
|
| 146 |
-
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
|
| 147 |
-
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
|
| 148 |
-
emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
|
| 149 |
-
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
| 150 |
-
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
|
| 151 |
-
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
|
| 152 |
-
**kwargs):
|
| 153 |
-
super().__init__()
|
| 154 |
-
self.cfg_coef = cfg_coef
|
| 155 |
-
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
| 156 |
-
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
| 157 |
-
self.condition_provider = condition_provider
|
| 158 |
-
self.fuser = fuser
|
| 159 |
-
self.card = card
|
| 160 |
-
embed_dim = self.card + 1
|
| 161 |
-
self.n_q = n_q
|
| 162 |
-
self.dim = dim
|
| 163 |
-
self.pattern_provider = pattern_provider
|
| 164 |
-
self.two_step_cfg = two_step_cfg
|
| 165 |
-
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
|
| 166 |
-
if 'activation' in kwargs:
|
| 167 |
-
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
| 168 |
-
self.transformer = StreamingTransformer(
|
| 169 |
-
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
| 170 |
-
norm=norm, norm_first=norm_first, **kwargs)
|
| 171 |
-
self.out_norm: tp.Optional[nn.Module] = None
|
| 172 |
-
if norm_first:
|
| 173 |
-
self.out_norm = create_norm_fn(norm, dim)
|
| 174 |
-
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
|
| 175 |
-
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
| 176 |
-
self._fsdp: tp.Optional[nn.Module]
|
| 177 |
-
self.__dict__['_fsdp'] = None
|
| 178 |
-
|
| 179 |
-
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
| 180 |
-
"""Initialization of the transformer module weights.
|
| 181 |
-
|
| 182 |
-
Args:
|
| 183 |
-
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
| 184 |
-
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
| 185 |
-
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
| 186 |
-
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
| 187 |
-
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
| 188 |
-
"""
|
| 189 |
-
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
| 190 |
-
assert depthwise_init is None or weight_init is not None, \
|
| 191 |
-
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
| 192 |
-
assert not zero_bias_init or weight_init is not None, \
|
| 193 |
-
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
| 194 |
-
|
| 195 |
-
if weight_init is None:
|
| 196 |
-
return
|
| 197 |
-
|
| 198 |
-
for emb_layer in self.emb:
|
| 199 |
-
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
| 200 |
-
|
| 201 |
-
for layer_idx, tr_layer in enumerate(self.transformer.layers):
|
| 202 |
-
depth = None
|
| 203 |
-
if depthwise_init == 'current':
|
| 204 |
-
depth = layer_idx + 1
|
| 205 |
-
elif depthwise_init == 'global':
|
| 206 |
-
depth = len(self.transformer.layers)
|
| 207 |
-
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
|
| 208 |
-
tr_layer.apply(init_fn)
|
| 209 |
-
|
| 210 |
-
for linear in self.linears:
|
| 211 |
-
init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
| 212 |
-
|
| 213 |
-
@property
|
| 214 |
-
def special_token_id(self) -> int:
|
| 215 |
-
return self.card
|
| 216 |
-
|
| 217 |
-
@property
|
| 218 |
-
def num_codebooks(self) -> int:
|
| 219 |
-
return self.n_q
|
| 220 |
-
|
| 221 |
-
def forward(self, sequence: torch.Tensor,
|
| 222 |
-
conditions: tp.List[ConditioningAttributes],
|
| 223 |
-
condition_tensors: tp.Optional[ConditionTensors] = None,
|
| 224 |
-
stage: int = -1) -> torch.Tensor:
|
| 225 |
-
"""Apply language model on sequence and conditions.
|
| 226 |
-
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
| 227 |
-
S the sequence steps, return the logits with shape [B, card, K, S].
|
| 228 |
-
|
| 229 |
-
Args:
|
| 230 |
-
indices (torch.Tensor): Indices of the codes to model.
|
| 231 |
-
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
| 232 |
-
the given codes. Note that when evaluating multiple time with the same conditioning
|
| 233 |
-
you should pre-compute those and pass them as `condition_tensors`.
|
| 234 |
-
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
| 235 |
-
tensors, see `conditions`.
|
| 236 |
-
stage (int): The codebook level that is being predicted. Relevant for MAGNeT
|
| 237 |
-
in which prediction is done in a codebook-by-codebook manner.
|
| 238 |
-
Takes values in range(n_q), and ignored by default.
|
| 239 |
-
Returns:
|
| 240 |
-
torch.Tensor: Logits.
|
| 241 |
-
"""
|
| 242 |
-
B, K, S = sequence.shape
|
| 243 |
-
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
|
| 244 |
-
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
|
| 245 |
-
if condition_tensors is None:
|
| 246 |
-
assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
|
| 247 |
-
# apply dropout modules
|
| 248 |
-
conditions = self.cfg_dropout(conditions)
|
| 249 |
-
conditions = self.att_dropout(conditions)
|
| 250 |
-
tokenized = self.condition_provider.tokenize(conditions)
|
| 251 |
-
# encode conditions and fuse, both have a streaming cache to not recompute when generating.
|
| 252 |
-
condition_tensors = self.condition_provider(tokenized)
|
| 253 |
-
else:
|
| 254 |
-
assert not conditions, "Shouldn't pass both conditions and condition_tensors."
|
| 255 |
-
|
| 256 |
-
input_, cross_attention_input = self.fuser(input_, condition_tensors)
|
| 257 |
-
|
| 258 |
-
out = self.transformer(input_, cross_attention_src=cross_attention_input,
|
| 259 |
-
src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) # type: ignore
|
| 260 |
-
if self.out_norm:
|
| 261 |
-
out = self.out_norm(out)
|
| 262 |
-
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
|
| 263 |
-
|
| 264 |
-
# remove the prefix from the model outputs
|
| 265 |
-
if len(self.fuser.fuse2cond['prepend']) > 0:
|
| 266 |
-
logits = logits[:, :, -S:]
|
| 267 |
-
|
| 268 |
-
return logits # [B, K, S, card]
|
| 269 |
-
|
| 270 |
-
def compute_predictions(
|
| 271 |
-
self, codes: torch.Tensor,
|
| 272 |
-
conditions: tp.List[ConditioningAttributes],
|
| 273 |
-
condition_tensors: tp.Optional[ConditionTensors] = None,
|
| 274 |
-
stage: int = -1,
|
| 275 |
-
keep_only_valid_steps: bool = True) -> LMOutput:
|
| 276 |
-
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
| 277 |
-
forward using the specified codes interleaving pattern.
|
| 278 |
-
|
| 279 |
-
Args:
|
| 280 |
-
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
| 281 |
-
K the number of codebooks and T the number of timesteps.
|
| 282 |
-
conditions (list of ConditioningAttributes): conditionings to use when modeling
|
| 283 |
-
the given codes. Note that when evaluating multiple time with the same conditioning
|
| 284 |
-
you should pre-compute those and pass them as `condition_tensors`.
|
| 285 |
-
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
| 286 |
-
tensors, see `conditions`.
|
| 287 |
-
stage (int): The codebook level that is being predicted. Relevant for MAGNeT
|
| 288 |
-
in which prediction is done in a codebook-by-codebook manner.
|
| 289 |
-
Takes values in range(n_q), and ignored by default.
|
| 290 |
-
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
| 291 |
-
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
| 292 |
-
Returns:
|
| 293 |
-
LMOutput: Language model outputs
|
| 294 |
-
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
| 295 |
-
i.e. the first item corresponds to logits to predict the first code, meaning that
|
| 296 |
-
no additional shifting of codes and logits is required.
|
| 297 |
-
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
|
| 298 |
-
Given the specified interleaving strategies, parts of the logits and codes should
|
| 299 |
-
not be considered as valid predictions because of invalid context.
|
| 300 |
-
"""
|
| 301 |
-
B, K, T = codes.shape
|
| 302 |
-
codes = codes.contiguous()
|
| 303 |
-
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
| 304 |
-
pattern = self.pattern_provider.get_pattern(T)
|
| 305 |
-
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
| 306 |
-
codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
|
| 307 |
-
)
|
| 308 |
-
|
| 309 |
-
# apply model on pattern sequence
|
| 310 |
-
model = self if self._fsdp is None else self._fsdp
|
| 311 |
-
logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card]
|
| 312 |
-
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
| 313 |
-
# and provide the corresponding mask over invalid positions of tokens
|
| 314 |
-
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
| 315 |
-
# note: we use nans as special token to make it obvious if we feed unexpected logits
|
| 316 |
-
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
| 317 |
-
logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
|
| 318 |
-
)
|
| 319 |
-
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
| 320 |
-
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
| 321 |
-
return LMOutput(logits, logits_mask)
|
| 322 |
-
|
| 323 |
-
def _sample_next_token(self,
|
| 324 |
-
sequence: torch.Tensor,
|
| 325 |
-
cfg_conditions: CFGConditions,
|
| 326 |
-
unconditional_state: State,
|
| 327 |
-
use_sampling: bool = False,
|
| 328 |
-
temp: float = 1.0,
|
| 329 |
-
top_k: int = 0,
|
| 330 |
-
top_p: float = 0.0,
|
| 331 |
-
cfg_coef: tp.Optional[float] = None,
|
| 332 |
-
cfg_coef_beta: tp.Optional[float] = None,
|
| 333 |
-
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
|
| 334 |
-
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
| 335 |
-
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
| 336 |
-
|
| 337 |
-
Args:
|
| 338 |
-
sequence (torch.Tensor): Current sequence of shape [B, K, S]
|
| 339 |
-
with K corresponding to the number of codebooks and S the number of sequence steps.
|
| 340 |
-
S = 1 in streaming mode, except for the first step that contains a bigger prompt.
|
| 341 |
-
condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
|
| 342 |
-
should be twice the batch size, being the concatenation of the conditions + null conditions.
|
| 343 |
-
use_sampling (bool): Whether to use a sampling strategy or not.
|
| 344 |
-
temp (float): Sampling temperature.
|
| 345 |
-
top_k (int): K for "top-k" sampling.
|
| 346 |
-
top_p (float): P for "top-p" sampling.
|
| 347 |
-
cfg_coef (float, optional): classifier free guidance coefficient
|
| 348 |
-
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
|
| 349 |
-
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
|
| 350 |
-
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
|
| 351 |
-
push the text condition more than the style condition in the case where both text and style
|
| 352 |
-
conditions are being used.
|
| 353 |
-
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
| 354 |
-
|
| 355 |
-
Returns:
|
| 356 |
-
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
| 357 |
-
"""
|
| 358 |
-
B = sequence.shape[0]
|
| 359 |
-
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
| 360 |
-
model = self if self._fsdp is None else self._fsdp
|
| 361 |
-
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
| 362 |
-
if cfg_coef_beta is not None:
|
| 363 |
-
assert isinstance(cfg_conditions, dict)
|
| 364 |
-
condition_tensors = cfg_conditions
|
| 365 |
-
if condition_tensors:
|
| 366 |
-
# Preparing for CFG, predicting conditional text and style, conditional style
|
| 367 |
-
# and unconditional
|
| 368 |
-
sequence = torch.cat([sequence, sequence, sequence], dim=0)
|
| 369 |
-
all_logits = model(
|
| 370 |
-
sequence,
|
| 371 |
-
conditions=[], condition_tensors=condition_tensors)
|
| 372 |
-
if condition_tensors:
|
| 373 |
-
cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
| 374 |
-
logits = uncond_logits + cfg_coef * (
|
| 375 |
-
wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
elif two_step_cfg and cfg_conditions != {}:
|
| 379 |
-
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
|
| 380 |
-
condition_tensors, null_condition_tensors = cfg_conditions
|
| 381 |
-
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
|
| 382 |
-
state = self.get_streaming_state()
|
| 383 |
-
self.set_streaming_state(unconditional_state)
|
| 384 |
-
uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
|
| 385 |
-
unconditional_state.update(self.get_streaming_state())
|
| 386 |
-
self.set_streaming_state(state)
|
| 387 |
-
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
|
| 388 |
-
else:
|
| 389 |
-
assert isinstance(cfg_conditions, dict)
|
| 390 |
-
condition_tensors = cfg_conditions
|
| 391 |
-
if condition_tensors:
|
| 392 |
-
# Preparing for CFG, predicting both conditional and unconditional logits.
|
| 393 |
-
sequence = torch.cat([sequence, sequence], dim=0)
|
| 394 |
-
all_logits = model(
|
| 395 |
-
sequence,
|
| 396 |
-
conditions=[], condition_tensors=condition_tensors)
|
| 397 |
-
if condition_tensors:
|
| 398 |
-
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
| 399 |
-
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
|
| 400 |
-
else:
|
| 401 |
-
logits = all_logits
|
| 402 |
-
|
| 403 |
-
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
| 404 |
-
logits = logits[..., -1] # [B x K x card]
|
| 405 |
-
|
| 406 |
-
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
|
| 407 |
-
if use_sampling and temp > 0.0:
|
| 408 |
-
probs = torch.softmax(logits / temp, dim=-1)
|
| 409 |
-
if top_p > 0.0:
|
| 410 |
-
next_token = utils.sample_top_p(probs, p=top_p)
|
| 411 |
-
elif top_k > 0:
|
| 412 |
-
next_token = utils.sample_top_k(probs, k=top_k)
|
| 413 |
-
else:
|
| 414 |
-
next_token = utils.multinomial(probs, num_samples=1)
|
| 415 |
-
else:
|
| 416 |
-
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
| 417 |
-
|
| 418 |
-
return next_token
|
| 419 |
-
|
| 420 |
-
@torch.no_grad()
|
| 421 |
-
def generate(self,
|
| 422 |
-
prompt: tp.Optional[torch.Tensor] = None,
|
| 423 |
-
conditions: tp.List[ConditioningAttributes] = [],
|
| 424 |
-
num_samples: tp.Optional[int] = None,
|
| 425 |
-
max_gen_len: int = 256,
|
| 426 |
-
use_sampling: bool = True,
|
| 427 |
-
temp: float = 1.0,
|
| 428 |
-
top_k: int = 250,
|
| 429 |
-
top_p: float = 0.0,
|
| 430 |
-
cfg_coef: tp.Optional[float] = None,
|
| 431 |
-
cfg_coef_beta: tp.Optional[float] = None,
|
| 432 |
-
two_step_cfg: tp.Optional[bool] = None,
|
| 433 |
-
remove_prompts: bool = False,
|
| 434 |
-
check: bool = False,
|
| 435 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
| 436 |
-
) -> torch.Tensor:
|
| 437 |
-
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
| 438 |
-
be performed in a greedy fashion or using sampling with top K and top P strategies.
|
| 439 |
-
|
| 440 |
-
Args:
|
| 441 |
-
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
| 442 |
-
conditions (list of ConditioningAttributes, optional): List of conditions.
|
| 443 |
-
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
| 444 |
-
max_gen_len (int): Maximum generation length.
|
| 445 |
-
use_sampling (bool): Whether to use a sampling strategy or not.
|
| 446 |
-
temp (float): Sampling temperature.
|
| 447 |
-
top_k (int): K for "top-k" sampling.
|
| 448 |
-
top_p (float): P for "top-p" sampling.
|
| 449 |
-
cfg_coef (float, optional): Classifier-free guidance coefficient.
|
| 450 |
-
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
|
| 451 |
-
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
|
| 452 |
-
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
|
| 453 |
-
push the text condition more than the style condition in the case where both text and style
|
| 454 |
-
conditions are being used.
|
| 455 |
-
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
| 456 |
-
remove_prompts (bool): Whether to remove prompts from generation or not.
|
| 457 |
-
check (bool): Whether to apply further checks on generated sequence.
|
| 458 |
-
callback (Callback, optional): Callback function to report generation progress.
|
| 459 |
-
Returns:
|
| 460 |
-
torch.Tensor: Generated tokens.
|
| 461 |
-
"""
|
| 462 |
-
assert not self.training, "generation shouldn't be used in training mode."
|
| 463 |
-
first_param = next(iter(self.parameters()))
|
| 464 |
-
device = first_param.device
|
| 465 |
-
|
| 466 |
-
# Checking all input shapes are consistent.
|
| 467 |
-
possible_num_samples = []
|
| 468 |
-
if num_samples is not None:
|
| 469 |
-
possible_num_samples.append(num_samples)
|
| 470 |
-
elif prompt is not None:
|
| 471 |
-
possible_num_samples.append(prompt.shape[0])
|
| 472 |
-
elif conditions:
|
| 473 |
-
possible_num_samples.append(len(conditions))
|
| 474 |
-
else:
|
| 475 |
-
possible_num_samples.append(1)
|
| 476 |
-
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
| 477 |
-
num_samples = possible_num_samples[0]
|
| 478 |
-
|
| 479 |
-
# below we create set of conditions: one conditional and one unconditional
|
| 480 |
-
# to do that we merge the regular condition together with the null condition
|
| 481 |
-
# we then do 1 forward pass instead of 2.
|
| 482 |
-
# the reason for that is two-fold:
|
| 483 |
-
# 1. it is about x2 faster than doing 2 forward passes
|
| 484 |
-
# 2. avoid the streaming API treating the 2 passes as part of different time steps
|
| 485 |
-
# We also support doing two different passes, in particular to ensure that
|
| 486 |
-
# the padding structure is exactly the same between train and test.
|
| 487 |
-
# With a batch size of 1, this can be slower though.
|
| 488 |
-
cfg_conditions: CFGConditions
|
| 489 |
-
cfg_conditions = {}
|
| 490 |
-
if cfg_coef_beta is not None:
|
| 491 |
-
if conditions:
|
| 492 |
-
wav_conditions = _drop_description_condition(conditions)
|
| 493 |
-
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
| 494 |
-
conditions = conditions + wav_conditions + null_conditions
|
| 495 |
-
tokenized = self.condition_provider.tokenize(conditions)
|
| 496 |
-
cfg_conditions = self.condition_provider(tokenized)
|
| 497 |
-
elif conditions:
|
| 498 |
-
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
| 499 |
-
if conditions:
|
| 500 |
-
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
| 501 |
-
if two_step_cfg:
|
| 502 |
-
cfg_conditions = (
|
| 503 |
-
self.condition_provider(self.condition_provider.tokenize(conditions)),
|
| 504 |
-
self.condition_provider(self.condition_provider.tokenize(null_conditions)),
|
| 505 |
-
)
|
| 506 |
-
else:
|
| 507 |
-
conditions = conditions + null_conditions
|
| 508 |
-
tokenized = self.condition_provider.tokenize(conditions)
|
| 509 |
-
cfg_conditions = self.condition_provider(tokenized)
|
| 510 |
-
else:
|
| 511 |
-
cfg_conditions = {}
|
| 512 |
-
|
| 513 |
-
if prompt is None:
|
| 514 |
-
assert num_samples > 0
|
| 515 |
-
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
|
| 516 |
-
|
| 517 |
-
B, K, T = prompt.shape
|
| 518 |
-
start_offset = T
|
| 519 |
-
print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
|
| 520 |
-
assert start_offset <= max_gen_len
|
| 521 |
-
|
| 522 |
-
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
| 523 |
-
# this token is used as default value for codes that are not generated yet
|
| 524 |
-
unknown_token = -1
|
| 525 |
-
|
| 526 |
-
# we generate codes up to the max_gen_len that will be mapped to the pattern sequence
|
| 527 |
-
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
|
| 528 |
-
# filling the gen_codes with the prompt if needed
|
| 529 |
-
gen_codes[..., :start_offset] = prompt
|
| 530 |
-
# create the gen_sequence with proper interleaving from the pattern: [B, K, S]
|
| 531 |
-
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
|
| 532 |
-
# retrieve the start_offset in the sequence:
|
| 533 |
-
# it is the first sequence step that contains the `start_offset` timestep
|
| 534 |
-
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
| 535 |
-
assert start_offset_sequence is not None
|
| 536 |
-
|
| 537 |
-
with self.streaming():
|
| 538 |
-
unconditional_state = self.get_streaming_state()
|
| 539 |
-
prev_offset = 0
|
| 540 |
-
gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
|
| 541 |
-
for offset in range(start_offset_sequence, gen_sequence_len):
|
| 542 |
-
# get current sequence (note that the streaming API is providing the caching over previous offsets)
|
| 543 |
-
curr_sequence = gen_sequence[..., prev_offset:offset]
|
| 544 |
-
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
|
| 545 |
-
if check:
|
| 546 |
-
# check coherence between mask and sequence
|
| 547 |
-
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
|
| 548 |
-
# should never happen as gen_sequence is filled progressively
|
| 549 |
-
assert not (curr_sequence == unknown_token).any()
|
| 550 |
-
# sample next token from the model, next token shape is [B, K, 1]
|
| 551 |
-
next_token = self._sample_next_token(
|
| 552 |
-
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
|
| 553 |
-
cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg)
|
| 554 |
-
# ensure the tokens that should be masked are properly set to special_token_id
|
| 555 |
-
# as the model never output special_token_id
|
| 556 |
-
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
| 557 |
-
next_token[~valid_mask] = self.special_token_id
|
| 558 |
-
# ensure we don't overwrite prompt tokens, we only write over unknown tokens
|
| 559 |
-
# (then mask tokens should be left as is as well, which is correct)
|
| 560 |
-
gen_sequence[..., offset:offset+1] = torch.where(
|
| 561 |
-
gen_sequence[..., offset:offset+1] == unknown_token,
|
| 562 |
-
next_token, gen_sequence[..., offset:offset+1]
|
| 563 |
-
)
|
| 564 |
-
prev_offset = offset
|
| 565 |
-
if callback is not None:
|
| 566 |
-
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
|
| 567 |
-
unconditional_state.clear()
|
| 568 |
-
|
| 569 |
-
# ensure sequence has been entirely filled
|
| 570 |
-
assert not (gen_sequence == unknown_token).any()
|
| 571 |
-
# ensure gen_sequence pattern and mask are matching
|
| 572 |
-
# which means the gen_sequence is valid according to the pattern
|
| 573 |
-
assert (
|
| 574 |
-
gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
|
| 575 |
-
).all()
|
| 576 |
-
# get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
|
| 577 |
-
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
| 578 |
-
|
| 579 |
-
# sanity checks over the returned codes and corresponding masks
|
| 580 |
-
assert (out_codes[..., :max_gen_len] != unknown_token).all()
|
| 581 |
-
assert (out_mask[..., :max_gen_len] == 1).all()
|
| 582 |
-
|
| 583 |
-
out_start_offset = start_offset if remove_prompts else 0
|
| 584 |
-
out_codes = out_codes[..., out_start_offset:max_gen_len]
|
| 585 |
-
|
| 586 |
-
# ensure the returned codes are all valid
|
| 587 |
-
assert (out_codes >= 0).all() and (out_codes <= self.card).all()
|
| 588 |
-
return out_codes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/lm_magnet.py
DELETED
|
@@ -1,500 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import logging
|
| 8 |
-
import math
|
| 9 |
-
import typing as tp
|
| 10 |
-
import torch
|
| 11 |
-
import numpy as np
|
| 12 |
-
|
| 13 |
-
from ..utils import utils
|
| 14 |
-
from ..modules.conditioners import (
|
| 15 |
-
ClassifierFreeGuidanceDropout,
|
| 16 |
-
ConditioningAttributes,
|
| 17 |
-
ConditionType,
|
| 18 |
-
)
|
| 19 |
-
from .lm import LMModel
|
| 20 |
-
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
ConditionTensors = tp.Dict[str, ConditionType]
|
| 23 |
-
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class MagnetLMModel(LMModel):
|
| 27 |
-
"""Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).
|
| 28 |
-
Args:
|
| 29 |
-
subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
|
| 30 |
-
When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
|
| 31 |
-
compression_model_framerate (int): frame rate of the audio tokenizer.
|
| 32 |
-
segment_duration (int): Sample length in seconds.
|
| 33 |
-
span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
|
| 34 |
-
for both training and inference. Defaults to 3.
|
| 35 |
-
**kwargs: Additional parameters for the LMModel.
|
| 36 |
-
"""
|
| 37 |
-
def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
|
| 38 |
-
segment_duration: int = 10, span_len: int = 3, **kwargs):
|
| 39 |
-
super().__init__(**kwargs)
|
| 40 |
-
self.causal = kwargs['causal']
|
| 41 |
-
self.subcodes_context = subcodes_context
|
| 42 |
-
self.span_len = span_len
|
| 43 |
-
self._build_attn_masks(compression_model_framerate=compression_model_framerate,
|
| 44 |
-
segment_duration=segment_duration,
|
| 45 |
-
num_heads=kwargs['num_heads'],
|
| 46 |
-
device=kwargs['device'], dtype=kwargs['dtype'])
|
| 47 |
-
|
| 48 |
-
def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
| 49 |
-
"""Creates a restricted attention mask (local attention map) where the context
|
| 50 |
-
is determined by self.subcodes_context.
|
| 51 |
-
Args:
|
| 52 |
-
seq_len (int): token sequence length.
|
| 53 |
-
device (torch.device): device of the output tensor.
|
| 54 |
-
dtype (torch.dtype): data type of the output tensor.
|
| 55 |
-
Returns:
|
| 56 |
-
torch.Tensor: The restricted attention mask.
|
| 57 |
-
"""
|
| 58 |
-
# Return a context restricted non-causal att mask
|
| 59 |
-
queries_pos = torch.arange(seq_len, device=device).view(-1, 1)
|
| 60 |
-
keys_pos = torch.arange(seq_len, device=device).view(1, -1)
|
| 61 |
-
|
| 62 |
-
delta = queries_pos - keys_pos
|
| 63 |
-
valid = torch.abs(delta) <= self.subcodes_context
|
| 64 |
-
return torch.where(
|
| 65 |
-
valid,
|
| 66 |
-
torch.zeros([], device=device, dtype=dtype),
|
| 67 |
-
torch.full([], float('-inf'), device=device, dtype=dtype))
|
| 68 |
-
|
| 69 |
-
def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
|
| 70 |
-
device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
|
| 71 |
-
"""Creates a restricted attention mask given the stage (codebook index).
|
| 72 |
-
Args:
|
| 73 |
-
stage (int): The codebook index. Takes values in [0, n_q].
|
| 74 |
-
seq_len (int): Token sequence length.
|
| 75 |
-
num_heads (int): Num transformer attention heads.
|
| 76 |
-
device (torch.device): device of the output tensor.
|
| 77 |
-
dtype (torch.dtype): data type of the output tensor.
|
| 78 |
-
Returns:
|
| 79 |
-
torch.Tensor: Either a restricted attention mask or None if stage attention is unrestricted.
|
| 80 |
-
"""
|
| 81 |
-
sa_mask = None
|
| 82 |
-
|
| 83 |
-
if stage > 0 and self.subcodes_context > -1:
|
| 84 |
-
# parallel - non-causal - with restricted subcodes context
|
| 85 |
-
sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
|
| 86 |
-
|
| 87 |
-
if sa_mask is not None:
|
| 88 |
-
# Repeat for each attention head
|
| 89 |
-
sa_mask = sa_mask.repeat((1, num_heads, 1, 1))
|
| 90 |
-
|
| 91 |
-
# align8 to enable memory efficient attention
|
| 92 |
-
MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
|
| 93 |
-
seq_len_aligned = \
|
| 94 |
-
int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR
|
| 95 |
-
|
| 96 |
-
sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
|
| 97 |
-
sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
|
| 98 |
-
sa_mask = sa_mask_aligned
|
| 99 |
-
|
| 100 |
-
return sa_mask
|
| 101 |
-
|
| 102 |
-
def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
|
| 103 |
-
device: torch.device, dtype: torch.dtype):
|
| 104 |
-
"""Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
|
| 105 |
-
either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
|
| 106 |
-
Args:
|
| 107 |
-
compression_model_framerate (int): The frame rate of the tokenizer.
|
| 108 |
-
segment_duration (int): Sample length in seconds.
|
| 109 |
-
num_heads (int): Num transformer attention heads.
|
| 110 |
-
device (torch.device): device of the output tensor.
|
| 111 |
-
dtype (torch.dtype): data type of the output tensor.
|
| 112 |
-
"""
|
| 113 |
-
seq_len = compression_model_framerate * segment_duration
|
| 114 |
-
self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
|
| 115 |
-
device, dtype) for stage in range(self.n_q)]
|
| 116 |
-
|
| 117 |
-
@torch.no_grad()
|
| 118 |
-
def generate(self,
|
| 119 |
-
prompt: tp.Optional[torch.Tensor] = None,
|
| 120 |
-
conditions: tp.List[ConditioningAttributes] = [],
|
| 121 |
-
num_samples: tp.Optional[int] = None,
|
| 122 |
-
max_gen_len: int = 256,
|
| 123 |
-
use_sampling: bool = True,
|
| 124 |
-
temp: float = 1.0,
|
| 125 |
-
top_k: int = 250,
|
| 126 |
-
top_p: float = 0.0,
|
| 127 |
-
cfg_coef: tp.Optional[float] = None,
|
| 128 |
-
cfg_coef_beta: tp.Optional[float] = None,
|
| 129 |
-
two_step_cfg: tp.Optional[bool] = None,
|
| 130 |
-
remove_prompts: bool = False,
|
| 131 |
-
check: bool = False,
|
| 132 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
| 133 |
-
**kwargs) -> torch.Tensor:
|
| 134 |
-
|
| 135 |
-
assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
|
| 136 |
-
assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
|
| 137 |
-
assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
|
| 138 |
-
assert check is False, "MAGNeT currently doesn't support the check arg."
|
| 139 |
-
assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg."
|
| 140 |
-
# Call the MAGNeT-specific generation method
|
| 141 |
-
return self._generate_magnet(prompt=prompt,
|
| 142 |
-
conditions=conditions,
|
| 143 |
-
num_samples=num_samples,
|
| 144 |
-
max_gen_len=max_gen_len,
|
| 145 |
-
use_sampling=use_sampling,
|
| 146 |
-
temp=temp,
|
| 147 |
-
top_k=top_k,
|
| 148 |
-
top_p=top_p,
|
| 149 |
-
callback=callback, **kwargs)
|
| 150 |
-
|
| 151 |
-
@torch.no_grad()
|
| 152 |
-
def _generate_magnet(self,
|
| 153 |
-
prompt: tp.Optional[torch.Tensor] = None,
|
| 154 |
-
conditions: tp.List[ConditioningAttributes] = [],
|
| 155 |
-
num_samples: tp.Optional[int] = None,
|
| 156 |
-
max_gen_len: int = 256,
|
| 157 |
-
use_sampling: bool = True,
|
| 158 |
-
temp: float = 3.0,
|
| 159 |
-
top_k: int = 0,
|
| 160 |
-
top_p: float = 0.9,
|
| 161 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
|
| 162 |
-
max_cfg_coef: float = 10.0,
|
| 163 |
-
min_cfg_coef: float = 1.0,
|
| 164 |
-
decoding_steps: tp.List[int] = [20, 10, 10, 10],
|
| 165 |
-
anneal_temp: bool = True,
|
| 166 |
-
span_scoring='max',
|
| 167 |
-
span_arrangement='nonoverlap') -> torch.Tensor:
|
| 168 |
-
"""Generate audio tokens given textual conditions, and optionally given audio prompts,
|
| 169 |
-
by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.
|
| 170 |
-
Args:
|
| 171 |
-
prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
|
| 172 |
-
conditions (list of ConditioningAttributes): List of conditions.
|
| 173 |
-
num_samples (int): Number of samples to generate when no prompt and no conditions are given.
|
| 174 |
-
max_gen_len (int): Maximum generation length.
|
| 175 |
-
use_sampling (bool): Whether to use a sampling strategy or not.
|
| 176 |
-
temp (float): Initial sampling temperature.
|
| 177 |
-
top_k (int): k for "top-k" sampling.
|
| 178 |
-
top_p (float): p for "top-p" sampling.
|
| 179 |
-
callback (Callback): Callback function to report generation progress.
|
| 180 |
-
max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
|
| 181 |
-
min_clsfg_coef (float): Final coefficient used for classifier free guidance.
|
| 182 |
-
decoding_steps (list of n_q ints): The number of iterative decoding steps,
|
| 183 |
-
for each of the n_q RVQ codebooks.
|
| 184 |
-
anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
|
| 185 |
-
span_scoring (str): Use the maximum probability of each span ('max')
|
| 186 |
-
or the product of probabilities ('prod').
|
| 187 |
-
span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
|
| 188 |
-
in the masking scheme.
|
| 189 |
-
Returns:
|
| 190 |
-
torch.Tensor: Generated tokens.
|
| 191 |
-
"""
|
| 192 |
-
assert not self.training, "generation shouldn't be used in training mode."
|
| 193 |
-
first_param = next(iter(self.parameters()))
|
| 194 |
-
device = first_param.device
|
| 195 |
-
|
| 196 |
-
# Checking all input shapes are consistent.
|
| 197 |
-
possible_num_samples = []
|
| 198 |
-
if num_samples is not None:
|
| 199 |
-
possible_num_samples.append(num_samples)
|
| 200 |
-
elif prompt is not None:
|
| 201 |
-
possible_num_samples.append(prompt.shape[0])
|
| 202 |
-
elif conditions:
|
| 203 |
-
possible_num_samples.append(len(conditions))
|
| 204 |
-
else:
|
| 205 |
-
possible_num_samples.append(1)
|
| 206 |
-
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
| 207 |
-
num_samples = possible_num_samples[0]
|
| 208 |
-
|
| 209 |
-
# below we create set of conditions: one conditional and one unconditional
|
| 210 |
-
# to do that we merge the regular condition together with the null condition
|
| 211 |
-
# we then do 1 forward pass instead of 2.
|
| 212 |
-
cfg_conditions: tp.Optional[ConditionTensors]
|
| 213 |
-
if conditions:
|
| 214 |
-
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
|
| 215 |
-
conditions = conditions + null_conditions
|
| 216 |
-
tokenized = self.condition_provider.tokenize(conditions)
|
| 217 |
-
cfg_conditions = self.condition_provider(tokenized)
|
| 218 |
-
else:
|
| 219 |
-
cfg_conditions = {}
|
| 220 |
-
|
| 221 |
-
if prompt is None:
|
| 222 |
-
assert num_samples > 0
|
| 223 |
-
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
|
| 224 |
-
|
| 225 |
-
B, K, prompt_length = prompt.shape
|
| 226 |
-
start_offset = prompt_length
|
| 227 |
-
assert start_offset < max_gen_len
|
| 228 |
-
|
| 229 |
-
mask_id = self.special_token_id
|
| 230 |
-
|
| 231 |
-
# we generate codes with a fixed sequence length
|
| 232 |
-
shape = (B, K, max_gen_len)
|
| 233 |
-
|
| 234 |
-
gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
|
| 235 |
-
# filling the gen_codes with the prompt if needed
|
| 236 |
-
gen_codes[..., :start_offset] = prompt
|
| 237 |
-
# create the gen_sequence with proper interleaving from the pattern: [B, K, S]
|
| 238 |
-
gen_sequence = gen_codes
|
| 239 |
-
|
| 240 |
-
curr_step = 0
|
| 241 |
-
for stage, n_steps in zip(range(self.n_q), decoding_steps):
|
| 242 |
-
gen_sequence, curr_step = self._generate_stage(gen_sequence,
|
| 243 |
-
cfg_conditions,
|
| 244 |
-
stage=stage,
|
| 245 |
-
device=device,
|
| 246 |
-
prompt_length=prompt_length,
|
| 247 |
-
prompt=prompt,
|
| 248 |
-
temp=temp,
|
| 249 |
-
max_cfg_coef=max_cfg_coef,
|
| 250 |
-
min_cfg_coef=min_cfg_coef,
|
| 251 |
-
top_k=top_k,
|
| 252 |
-
top_p=top_p,
|
| 253 |
-
timesteps=n_steps,
|
| 254 |
-
anneal_temp=anneal_temp,
|
| 255 |
-
span_scoring=span_scoring,
|
| 256 |
-
use_sampling=use_sampling,
|
| 257 |
-
span_arrangement=span_arrangement,
|
| 258 |
-
curr_step=curr_step,
|
| 259 |
-
total_steps=sum(decoding_steps),
|
| 260 |
-
callback=callback)
|
| 261 |
-
|
| 262 |
-
return gen_sequence
|
| 263 |
-
|
| 264 |
-
@torch.no_grad()
|
| 265 |
-
def _generate_stage(self,
|
| 266 |
-
gen_sequence: torch.Tensor,
|
| 267 |
-
condition_tensors: tp.Optional[ConditionTensors],
|
| 268 |
-
stage: int,
|
| 269 |
-
device: torch.device,
|
| 270 |
-
prompt_length: int = 0,
|
| 271 |
-
prompt: tp.Optional[torch.Tensor] = None,
|
| 272 |
-
use_sampling: bool = True,
|
| 273 |
-
temp: float = 3.0,
|
| 274 |
-
max_cfg_coef: float = 10.0,
|
| 275 |
-
min_cfg_coef: float = 1.0,
|
| 276 |
-
top_k: int = 0,
|
| 277 |
-
top_p: float = 0.0,
|
| 278 |
-
timesteps: int = 10,
|
| 279 |
-
anneal_temp: bool = True,
|
| 280 |
-
span_scoring: str = 'max',
|
| 281 |
-
span_arrangement: str = 'nonoverlap',
|
| 282 |
-
curr_step: int = 0,
|
| 283 |
-
total_steps: int = 0,
|
| 284 |
-
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
|
| 285 |
-
"""Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
|
| 286 |
-
and the textual conditions.
|
| 287 |
-
Args:
|
| 288 |
-
gen_sequence (torch.Tensor): Previously generated tokens.
|
| 289 |
-
condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors.
|
| 290 |
-
stage (int): RVQ level to generate.
|
| 291 |
-
device (torch.device): device of the output tensor.
|
| 292 |
-
prompt_length (int): Temporal length of the audio prompt.
|
| 293 |
-
prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
|
| 294 |
-
use_sampling (bool): Whether to use a sampling strategy or not.
|
| 295 |
-
temp (float): Initial sampling temperature.
|
| 296 |
-
max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
|
| 297 |
-
min_clsfg_coef (float): Final coefficient used for classifier free guidance.
|
| 298 |
-
top_k (int): k for "top-k" sampling.
|
| 299 |
-
top_p (float): p for "top-p" sampling.
|
| 300 |
-
timesteps (int): Number of iterative decoding steps.
|
| 301 |
-
anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
|
| 302 |
-
span_scoring (str): Use the maximum probability of each span ('max')
|
| 303 |
-
or the product of probabilities ('prod').
|
| 304 |
-
span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
|
| 305 |
-
in the masking scheme.
|
| 306 |
-
curr_step (int): Global iterative decoding step counter.
|
| 307 |
-
total_steps (int): Total decoding steps.
|
| 308 |
-
callback (Callback): Callback function to report generation progress.
|
| 309 |
-
Returns:
|
| 310 |
-
tuple(torch.Tensor, int): Generated tokens and the current decoding step counter.
|
| 311 |
-
"""
|
| 312 |
-
B, K, T = gen_sequence.shape
|
| 313 |
-
shape = (B, 1, T) # generating a single codebook per stage
|
| 314 |
-
|
| 315 |
-
mask_id = self.special_token_id
|
| 316 |
-
stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)
|
| 317 |
-
|
| 318 |
-
assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1'
|
| 319 |
-
chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
|
| 320 |
-
|
| 321 |
-
DONT_REMASK_ME_SCORE = -1e4
|
| 322 |
-
|
| 323 |
-
model = self if self._fsdp is None else self._fsdp
|
| 324 |
-
|
| 325 |
-
if chunk_masking:
|
| 326 |
-
# span-wise scores
|
| 327 |
-
n_chunks = T // self.span_len
|
| 328 |
-
if T % self.span_len != 0:
|
| 329 |
-
# trim sequence ending to achieve a multiple of span_len
|
| 330 |
-
T = self.span_len * n_chunks
|
| 331 |
-
gen_sequence = gen_sequence[..., :T]
|
| 332 |
-
stage_gen_seq = stage_gen_seq[..., :T]
|
| 333 |
-
|
| 334 |
-
chunked_shape = (B, 1, n_chunks)
|
| 335 |
-
n_prompt_chunks = prompt_length // self.span_len
|
| 336 |
-
scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
|
| 337 |
-
scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
|
| 338 |
-
num_chunks_to_gen = n_chunks - n_prompt_chunks
|
| 339 |
-
else:
|
| 340 |
-
# token-wise scores
|
| 341 |
-
scores = torch.zeros(shape, dtype=torch.float32, device=device)
|
| 342 |
-
scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
|
| 343 |
-
gen_T = T - prompt_length
|
| 344 |
-
|
| 345 |
-
# run MAGNeT iterative decoding for "timesteps" iterations
|
| 346 |
-
for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
|
| 347 |
-
|
| 348 |
-
mask_p = torch.cos(timestep * math.pi * 0.5)
|
| 349 |
-
|
| 350 |
-
if chunk_masking:
|
| 351 |
-
num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
|
| 352 |
-
else:
|
| 353 |
-
num_masked = max(int((mask_p * gen_T).item()), 1)
|
| 354 |
-
|
| 355 |
-
# masking
|
| 356 |
-
run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1
|
| 357 |
-
if run_lps_masking:
|
| 358 |
-
# masking of the k least probable overlapping (stride 1) spans
|
| 359 |
-
mask = torch.concat((
|
| 360 |
-
[self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
|
| 361 |
-
for i in range(B)]), dim=0)
|
| 362 |
-
stage_gen_seq[mask] = mask_id
|
| 363 |
-
else:
|
| 364 |
-
# masking of the k least probable non-overlapping spans
|
| 365 |
-
masked = scores.topk(num_masked, dim=-1).indices
|
| 366 |
-
if chunk_masking:
|
| 367 |
-
chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
|
| 368 |
-
chunks_mask = chunks_mask.scatter(2, masked, True)
|
| 369 |
-
mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
|
| 370 |
-
stage_gen_seq[mask] = mask_id
|
| 371 |
-
else:
|
| 372 |
-
stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)
|
| 373 |
-
|
| 374 |
-
if prompt is not None:
|
| 375 |
-
stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)
|
| 376 |
-
|
| 377 |
-
gen_sequence[:, [stage], :] = stage_gen_seq
|
| 378 |
-
if condition_tensors:
|
| 379 |
-
# duplicate input for classifier free guidance
|
| 380 |
-
sequence = torch.cat([gen_sequence, gen_sequence], dim=0)
|
| 381 |
-
|
| 382 |
-
all_logits = model(sequence, [], condition_tensors, stage=stage)
|
| 383 |
-
|
| 384 |
-
if condition_tensors:
|
| 385 |
-
# classifier free guidance with annealing
|
| 386 |
-
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
| 387 |
-
clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
|
| 388 |
-
logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
|
| 389 |
-
else:
|
| 390 |
-
logits = all_logits
|
| 391 |
-
|
| 392 |
-
# temperature annealing - linear
|
| 393 |
-
t = temp * (steps_left / timesteps) if anneal_temp else temp
|
| 394 |
-
|
| 395 |
-
# sampling
|
| 396 |
-
logits = logits[:, stage, :, :].unsqueeze(1)
|
| 397 |
-
probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
|
| 398 |
-
if use_sampling:
|
| 399 |
-
if top_p > 0.0:
|
| 400 |
-
sampled_tokens = utils.sample_top_p(probs, p=top_p)
|
| 401 |
-
elif top_k > 0:
|
| 402 |
-
sampled_tokens = utils.sample_top_k(probs, k=top_k)
|
| 403 |
-
else:
|
| 404 |
-
sampled_tokens = utils.multinomial(probs, num_samples=1)
|
| 405 |
-
else:
|
| 406 |
-
sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)
|
| 407 |
-
|
| 408 |
-
# place mask_id token in each of the masked positions
|
| 409 |
-
mask = stage_gen_seq == mask_id
|
| 410 |
-
stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
|
| 411 |
-
gen_sequence[:, [stage], :] = stage_gen_seq
|
| 412 |
-
|
| 413 |
-
# get probs of sampled tokens
|
| 414 |
-
sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]
|
| 415 |
-
|
| 416 |
-
# span scoring
|
| 417 |
-
if chunk_masking:
|
| 418 |
-
if span_scoring == 'max':
|
| 419 |
-
# max in linear space
|
| 420 |
-
scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
|
| 421 |
-
elif span_scoring == 'prod':
|
| 422 |
-
# prod in log space
|
| 423 |
-
scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
|
| 424 |
-
else:
|
| 425 |
-
raise NotImplementedError
|
| 426 |
-
else:
|
| 427 |
-
# prod in log space for lps masking (stride1)
|
| 428 |
-
scores = -torch.log(sampled_probs)
|
| 429 |
-
|
| 430 |
-
# Fix unmasked tokens by placing inf probs (-inf scores)
|
| 431 |
-
if chunk_masking:
|
| 432 |
-
scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
|
| 433 |
-
else:
|
| 434 |
-
scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)
|
| 435 |
-
|
| 436 |
-
if callback is not None:
|
| 437 |
-
curr_step += 1
|
| 438 |
-
callback(curr_step, total_steps)
|
| 439 |
-
|
| 440 |
-
return gen_sequence, curr_step
|
| 441 |
-
|
| 442 |
-
def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
|
| 443 |
-
"""Build a [1x1xT] boolean mask consists of overlapping spans of True values, where
|
| 444 |
-
span_starts defines the initial index of each span, and the span length is
|
| 445 |
-
defined by self.span_len.
|
| 446 |
-
Args:
|
| 447 |
-
span_starts (torch.Tensor): Boolean mask determines the temporal location of each span start.
|
| 448 |
-
T (int): Sequence length.
|
| 449 |
-
device (torch.device): device of the output tensor.
|
| 450 |
-
Returns:
|
| 451 |
-
torch.Tensor: Spans mask of shape [1x1xT]
|
| 452 |
-
"""
|
| 453 |
-
mask = torch.full((1, 1, T), False, device=device)
|
| 454 |
-
mask[:, :, span_starts] = True
|
| 455 |
-
shifted_mask = mask.clone()
|
| 456 |
-
for _ in range(self.span_len - 1):
|
| 457 |
-
shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1)
|
| 458 |
-
mask = torch.logical_or(mask, shifted_mask)
|
| 459 |
-
return mask
|
| 460 |
-
|
| 461 |
-
def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
|
| 462 |
-
"""Construct a [1x1xT] boolean mask, consists of the u least probable spans,
|
| 463 |
-
where the token probability is determined by -scores, and the total
|
| 464 |
-
number of masked tokens is as closest as possible to num_masked_trg.
|
| 465 |
-
Find u using binary search.
|
| 466 |
-
Args:
|
| 467 |
-
scores (torch.Tensor): Per token score [-log(prob)]
|
| 468 |
-
num_masked_trg: int: The desired amount of tokens to be masked.
|
| 469 |
-
Returns:
|
| 470 |
-
torch.Tensor: Spans mask of shape [1x1xT]
|
| 471 |
-
"""
|
| 472 |
-
T = scores.shape[-1]
|
| 473 |
-
device = scores.device
|
| 474 |
-
scores_unfolded = scores.unfold(2, self.span_len, 1)
|
| 475 |
-
# Span score is the product of probs (sum in log space)
|
| 476 |
-
span_scores = scores_unfolded.sum(dim=-1)
|
| 477 |
-
spans_by_scores = torch.argsort(span_scores[0, 0], descending=True)
|
| 478 |
-
|
| 479 |
-
num_masked_trg = max(num_masked_trg, self.span_len)
|
| 480 |
-
|
| 481 |
-
# Binary search for u - the number least probable overlapping masked spans s.t.
|
| 482 |
-
# the total masking rate is the closest to num_masked_trg / T.
|
| 483 |
-
min_u = num_masked_trg // self.span_len
|
| 484 |
-
max_u = num_masked_trg - self.span_len + 1
|
| 485 |
-
mid = round(0.5 * (min_u + max_u))
|
| 486 |
-
|
| 487 |
-
if mid == min_u or mid == max_u:
|
| 488 |
-
return self._construct_spans_mask(spans_by_scores[:mid], T, device)
|
| 489 |
-
|
| 490 |
-
while mid > min_u and mid < max_u:
|
| 491 |
-
mask = self._construct_spans_mask(spans_by_scores[:mid], T, device)
|
| 492 |
-
n_masked = mask.sum()
|
| 493 |
-
if n_masked > num_masked_trg:
|
| 494 |
-
max_u = mid
|
| 495 |
-
mid = round(0.5 * (min_u + max_u))
|
| 496 |
-
else:
|
| 497 |
-
min_u = mid
|
| 498 |
-
mid = round(0.5 * (min_u + max_u))
|
| 499 |
-
|
| 500 |
-
return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/loaders.py
DELETED
|
@@ -1,291 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Utility functions to load from the checkpoints.
|
| 9 |
-
Each checkpoint is a torch.saved dict with the following keys:
|
| 10 |
-
- 'xp.cfg': the hydra config as dumped during training. This should be used
|
| 11 |
-
to rebuild the object using the audiocraft.models.builders functions,
|
| 12 |
-
- 'model_best_state': a readily loadable best state for the model, including
|
| 13 |
-
the conditioner. The model obtained from `xp.cfg` should be compatible
|
| 14 |
-
with this state dict. In the case of a LM, the encodec model would not be
|
| 15 |
-
bundled along but instead provided separately.
|
| 16 |
-
|
| 17 |
-
Those functions also support loading from a remote location with the Torch Hub API.
|
| 18 |
-
They also support overriding some parameters, in particular the device and dtype
|
| 19 |
-
of the returned model.
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
from pathlib import Path
|
| 23 |
-
from huggingface_hub import hf_hub_download
|
| 24 |
-
import typing as tp
|
| 25 |
-
import os
|
| 26 |
-
|
| 27 |
-
from omegaconf import OmegaConf, DictConfig
|
| 28 |
-
import torch
|
| 29 |
-
|
| 30 |
-
import audiocraft
|
| 31 |
-
|
| 32 |
-
from . import builders
|
| 33 |
-
from .encodec import CompressionModel
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def get_audiocraft_cache_dir() -> tp.Optional[str]:
|
| 37 |
-
return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
HF_MODEL_CHECKPOINTS_MAP = {
|
| 41 |
-
"small": "facebook/musicgen-small",
|
| 42 |
-
"medium": "facebook/musicgen-medium",
|
| 43 |
-
"large": "facebook/musicgen-large",
|
| 44 |
-
"melody": "facebook/musicgen-melody",
|
| 45 |
-
"melody-large": "facebook/musicgen-melody-large",
|
| 46 |
-
"stereo-small": "facebook/musicgen-stereo-small",
|
| 47 |
-
"stereo-medium": "facebook/musicgen-stereo-medium",
|
| 48 |
-
"stereo-large": "facebook/musicgen-stereo-large",
|
| 49 |
-
"stereo-melody": "facebook/musicgen-stereo-melody",
|
| 50 |
-
"stereo-melody-large": "facebook/musicgen-stereo-melody-large",
|
| 51 |
-
"style": "facebook/musicgen-style",
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def _get_state_dict(
|
| 56 |
-
file_or_url_or_id: tp.Union[Path, str],
|
| 57 |
-
filename: tp.Optional[str] = None,
|
| 58 |
-
device='cpu',
|
| 59 |
-
cache_dir: tp.Optional[str] = None,
|
| 60 |
-
):
|
| 61 |
-
if cache_dir is None:
|
| 62 |
-
cache_dir = get_audiocraft_cache_dir()
|
| 63 |
-
# Return the state dict either from a file or url
|
| 64 |
-
file_or_url_or_id = str(file_or_url_or_id)
|
| 65 |
-
assert isinstance(file_or_url_or_id, str)
|
| 66 |
-
|
| 67 |
-
if os.path.isfile(file_or_url_or_id):
|
| 68 |
-
return torch.load(file_or_url_or_id, map_location=device)
|
| 69 |
-
|
| 70 |
-
if os.path.isdir(file_or_url_or_id):
|
| 71 |
-
file = f"{file_or_url_or_id}/{filename}"
|
| 72 |
-
return torch.load(file, map_location=device)
|
| 73 |
-
|
| 74 |
-
elif file_or_url_or_id.startswith('https://'):
|
| 75 |
-
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
| 76 |
-
|
| 77 |
-
elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
|
| 78 |
-
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
| 79 |
-
|
| 80 |
-
repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
|
| 81 |
-
file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
|
| 82 |
-
return torch.load(file, map_location=device)
|
| 83 |
-
|
| 84 |
-
else:
|
| 85 |
-
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
| 86 |
-
|
| 87 |
-
file = hf_hub_download(
|
| 88 |
-
repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
|
| 89 |
-
library_name="audiocraft", library_version=audiocraft.__version__)
|
| 90 |
-
return torch.load(file, map_location=device)
|
| 91 |
-
|
| 92 |
-
def create_melody_config(model_id: str, device: str) -> DictConfig:
|
| 93 |
-
"""Create a fallback configuration for melody models.
|
| 94 |
-
|
| 95 |
-
Args:
|
| 96 |
-
model_id: The model identifier
|
| 97 |
-
device: The device to use
|
| 98 |
-
|
| 99 |
-
Returns:
|
| 100 |
-
A compatible OmegaConf DictConfig
|
| 101 |
-
"""
|
| 102 |
-
base_cfg = {
|
| 103 |
-
"device": str(device),
|
| 104 |
-
"channels": 2 if "stereo" in model_id else 1,
|
| 105 |
-
"sample_rate": 32000,
|
| 106 |
-
"audio_channels": 2 if "stereo" in model_id else 1,
|
| 107 |
-
"frame_rate": 50,
|
| 108 |
-
"codec_name": "encodec",
|
| 109 |
-
"codec": {
|
| 110 |
-
"dim": 128,
|
| 111 |
-
"hidden_dim": 1024,
|
| 112 |
-
"stride": 320,
|
| 113 |
-
"n_q": 4,
|
| 114 |
-
"codebook_size": 2048,
|
| 115 |
-
"normalize": True,
|
| 116 |
-
}
|
| 117 |
-
}
|
| 118 |
-
return OmegaConf.create(base_cfg)
|
| 119 |
-
|
| 120 |
-
def create_default_config(model_id: str, device: str) -> DictConfig:
|
| 121 |
-
"""Create a fallback configuration for standard models.
|
| 122 |
-
|
| 123 |
-
Args:
|
| 124 |
-
model_id: The model identifier
|
| 125 |
-
device: The device to use
|
| 126 |
-
|
| 127 |
-
Returns:
|
| 128 |
-
A compatible OmegaConf DictConfig
|
| 129 |
-
"""
|
| 130 |
-
base_cfg = {
|
| 131 |
-
"device": str(device),
|
| 132 |
-
"channels": 2 if "stereo" in model_id else 1,
|
| 133 |
-
"sample_rate": 32000,
|
| 134 |
-
"audio_channels": 2 if "stereo" in model_id else 1,
|
| 135 |
-
"frame_rate": 50,
|
| 136 |
-
"codec_name": "encodec",
|
| 137 |
-
"codec": {
|
| 138 |
-
"dim": 128,
|
| 139 |
-
"hidden_dim": 1024,
|
| 140 |
-
"stride": 320,
|
| 141 |
-
"n_q": 4,
|
| 142 |
-
"codebook_size": 1024,
|
| 143 |
-
"normalize": True,
|
| 144 |
-
}
|
| 145 |
-
}
|
| 146 |
-
return OmegaConf.create(base_cfg)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
| 150 |
-
return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
| 154 |
-
pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 155 |
-
if 'pretrained' in pkg:
|
| 156 |
-
return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
|
| 157 |
-
|
| 158 |
-
# Handle newer model formats that might not have xp.cfg
|
| 159 |
-
if 'xp.cfg' not in pkg:
|
| 160 |
-
if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
|
| 161 |
-
'stereo-small', 'stereo-large', 'stereo-melody-large','style']:
|
| 162 |
-
print(f"Using fallback configuration for {file_or_url_or_id}")
|
| 163 |
-
# Create a default configuration based on the model type
|
| 164 |
-
# This is where you'd need to add model-specific configurations
|
| 165 |
-
if 'melody' in file_or_url_or_id:
|
| 166 |
-
cfg = create_melody_config(file_or_url_or_id, device)
|
| 167 |
-
else:
|
| 168 |
-
cfg = create_default_config(file_or_url_or_id, device)
|
| 169 |
-
else:
|
| 170 |
-
raise KeyError(f"Missing configuration for model {file_or_url_or_id}")
|
| 171 |
-
else:
|
| 172 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 173 |
-
|
| 174 |
-
cfg.device = str(device)
|
| 175 |
-
model = builders.get_compression_model(cfg)
|
| 176 |
-
model.load_state_dict(pkg['best_state'])
|
| 177 |
-
model.eval()
|
| 178 |
-
return model
|
| 179 |
-
|
| 180 |
-
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
| 181 |
-
return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def _delete_param(cfg: DictConfig, full_name: str):
|
| 185 |
-
parts = full_name.split('.')
|
| 186 |
-
for part in parts[:-1]:
|
| 187 |
-
if part in cfg:
|
| 188 |
-
cfg = cfg[part]
|
| 189 |
-
else:
|
| 190 |
-
return
|
| 191 |
-
OmegaConf.set_struct(cfg, False)
|
| 192 |
-
if parts[-1] in cfg:
|
| 193 |
-
del cfg[parts[-1]]
|
| 194 |
-
OmegaConf.set_struct(cfg, True)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
| 198 |
-
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 199 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 200 |
-
cfg.device = str(device)
|
| 201 |
-
if cfg.device == 'cpu':
|
| 202 |
-
cfg.transformer_lm.memory_efficient = False
|
| 203 |
-
cfg.transformer_lm.custom = True
|
| 204 |
-
cfg.dtype = 'float32'
|
| 205 |
-
else:
|
| 206 |
-
cfg.dtype = 'float16'
|
| 207 |
-
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
| 208 |
-
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
| 209 |
-
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
| 210 |
-
model = builders.get_lm_model(cfg)
|
| 211 |
-
model.load_state_dict(pkg['best_state'])
|
| 212 |
-
model.eval()
|
| 213 |
-
model.cfg = cfg
|
| 214 |
-
return model
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
|
| 218 |
-
device='cpu', cache_dir: tp.Optional[str] = None):
|
| 219 |
-
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 220 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 221 |
-
cfg.device = str(device)
|
| 222 |
-
if cfg.device == 'cpu':
|
| 223 |
-
cfg.dtype = 'float32'
|
| 224 |
-
else:
|
| 225 |
-
cfg.dtype = 'float16'
|
| 226 |
-
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
| 227 |
-
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
| 228 |
-
|
| 229 |
-
cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
|
| 230 |
-
cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
|
| 231 |
-
cfg.transformer_lm.span_len = cfg.masking.span_len
|
| 232 |
-
|
| 233 |
-
# MAGNeT models v1 support only xformers backend.
|
| 234 |
-
from audiocraft.modules.transformer import set_efficient_attention_backend
|
| 235 |
-
|
| 236 |
-
if cfg.transformer_lm.memory_efficient:
|
| 237 |
-
set_efficient_attention_backend("xformers")
|
| 238 |
-
|
| 239 |
-
model = builders.get_lm_model(cfg)
|
| 240 |
-
model.load_state_dict(pkg['best_state'])
|
| 241 |
-
model.eval()
|
| 242 |
-
model.cfg = cfg
|
| 243 |
-
return model
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def load_jasco_model(file_or_url_or_id: tp.Union[Path, str],
|
| 247 |
-
compression_model: CompressionModel,
|
| 248 |
-
device='cpu', cache_dir: tp.Optional[str] = None):
|
| 249 |
-
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
| 250 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 251 |
-
cfg.device = str(device)
|
| 252 |
-
if cfg.device == 'cpu':
|
| 253 |
-
cfg.dtype = 'float32'
|
| 254 |
-
else:
|
| 255 |
-
cfg.dtype = 'float16'
|
| 256 |
-
model = builders.get_jasco_model(cfg, compression_model)
|
| 257 |
-
model.load_state_dict(pkg['best_state'])
|
| 258 |
-
model.eval()
|
| 259 |
-
model.cfg = cfg
|
| 260 |
-
return model
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
|
| 264 |
-
filename: tp.Optional[str] = None,
|
| 265 |
-
cache_dir: tp.Optional[str] = None):
|
| 266 |
-
return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
|
| 270 |
-
device='cpu',
|
| 271 |
-
filename: tp.Optional[str] = None,
|
| 272 |
-
cache_dir: tp.Optional[str] = None):
|
| 273 |
-
pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
| 274 |
-
models = []
|
| 275 |
-
processors = []
|
| 276 |
-
cfgs = []
|
| 277 |
-
sample_rate = pkg['sample_rate']
|
| 278 |
-
for i in range(pkg['n_bands']):
|
| 279 |
-
cfg = pkg[i]['cfg']
|
| 280 |
-
model = builders.get_diffusion_model(cfg)
|
| 281 |
-
model_dict = pkg[i]['model_state']
|
| 282 |
-
model.load_state_dict(model_dict)
|
| 283 |
-
model.to(device)
|
| 284 |
-
processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
|
| 285 |
-
processor_dict = pkg[i]['processor_state']
|
| 286 |
-
processor.load_state_dict(processor_dict)
|
| 287 |
-
processor.to(device)
|
| 288 |
-
models.append(model)
|
| 289 |
-
processors.append(processor)
|
| 290 |
-
cfgs.append(cfg)
|
| 291 |
-
return models, processors, cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/magnet.py
DELETED
|
@@ -1,88 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Main model for using MAGNeT. This will combine all the required components
|
| 9 |
-
and provide easy access to the generation API.
|
| 10 |
-
"""
|
| 11 |
-
import typing as tp
|
| 12 |
-
import torch
|
| 13 |
-
|
| 14 |
-
from .genmodel import BaseGenModel
|
| 15 |
-
from .loaders import load_compression_model, load_lm_model_magnet
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class MAGNeT(BaseGenModel):
|
| 19 |
-
"""MAGNeT main model with convenient generation API.
|
| 20 |
-
Args:
|
| 21 |
-
See MusicGen class.
|
| 22 |
-
"""
|
| 23 |
-
def __init__(self, **kwargs):
|
| 24 |
-
super().__init__(**kwargs)
|
| 25 |
-
# MAGNeT operates over a fixed sequence length defined in it's config.
|
| 26 |
-
self.duration = self.lm.cfg.dataset.segment_duration
|
| 27 |
-
self.set_generation_params()
|
| 28 |
-
|
| 29 |
-
@staticmethod
|
| 30 |
-
def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
|
| 31 |
-
"""Return pretrained model, we provide six models:
|
| 32 |
-
- facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
|
| 33 |
-
# see: https://huggingface.co/facebook/magnet-small-10secs
|
| 34 |
-
- facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
|
| 35 |
-
# see: https://huggingface.co/facebook/magnet-medium-10secs
|
| 36 |
-
- facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
|
| 37 |
-
# see: https://huggingface.co/facebook/magnet-small-30secs
|
| 38 |
-
- facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
|
| 39 |
-
# see: https://huggingface.co/facebook/magnet-medium-30secs
|
| 40 |
-
- facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
|
| 41 |
-
# see: https://huggingface.co/facebook/audio-magnet-small
|
| 42 |
-
- facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
|
| 43 |
-
# see: https://huggingface.co/facebook/audio-magnet-medium
|
| 44 |
-
"""
|
| 45 |
-
if device is None:
|
| 46 |
-
if torch.cuda.device_count():
|
| 47 |
-
device = 'cuda'
|
| 48 |
-
else:
|
| 49 |
-
device = 'cpu'
|
| 50 |
-
|
| 51 |
-
compression_model = load_compression_model(name, device=device)
|
| 52 |
-
lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)
|
| 53 |
-
|
| 54 |
-
if 'self_wav' in lm.condition_provider.conditioners:
|
| 55 |
-
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
| 56 |
-
|
| 57 |
-
kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
|
| 58 |
-
return MAGNeT(**kwargs)
|
| 59 |
-
|
| 60 |
-
def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
|
| 61 |
-
top_p: float = 0.9, temperature: float = 3.0,
|
| 62 |
-
max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
|
| 63 |
-
decoding_steps: tp.List[int] = [20, 10, 10, 10],
|
| 64 |
-
span_arrangement: str = 'nonoverlap'):
|
| 65 |
-
"""Set the generation parameters for MAGNeT.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
| 69 |
-
top_k (int, optional): top_k used for sampling. Defaults to 0.
|
| 70 |
-
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
|
| 71 |
-
temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
|
| 72 |
-
max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
|
| 73 |
-
min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
|
| 74 |
-
decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
|
| 75 |
-
for each of the n_q RVQ codebooks.
|
| 76 |
-
span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
|
| 77 |
-
or overlapping spans ('stride1') in the masking scheme.
|
| 78 |
-
"""
|
| 79 |
-
self.generation_params = {
|
| 80 |
-
'use_sampling': use_sampling,
|
| 81 |
-
'temp': temperature,
|
| 82 |
-
'top_k': top_k,
|
| 83 |
-
'top_p': top_p,
|
| 84 |
-
'max_cfg_coef': max_cfg_coef,
|
| 85 |
-
'min_cfg_coef': min_cfg_coef,
|
| 86 |
-
'decoding_steps': [int(s) for s in decoding_steps],
|
| 87 |
-
'span_arrangement': span_arrangement
|
| 88 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/musicgen.py
DELETED
|
@@ -1,566 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Main model for using MusicGen. This will combine all the required components
|
| 9 |
-
and provide easy access to the generation API.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import os
|
| 13 |
-
import typing as tp
|
| 14 |
-
import warnings
|
| 15 |
-
|
| 16 |
-
import omegaconf
|
| 17 |
-
import torch
|
| 18 |
-
import gradio as gr
|
| 19 |
-
|
| 20 |
-
from .encodec import CompressionModel
|
| 21 |
-
from .genmodel import BaseGenModel
|
| 22 |
-
from .lm import LMModel
|
| 23 |
-
from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
|
| 24 |
-
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
| 25 |
-
from ..data.audio_utils import convert_audio
|
| 26 |
-
from ..modules.conditioners import ConditioningAttributes, WavCondition, StyleConditioner
|
| 27 |
-
from ..utils.autocast import TorchAutocast
|
| 28 |
-
|
| 29 |
-
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
| 30 |
-
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class MusicGen:
|
| 34 |
-
"""MusicGen main model with convenient generation API.
|
| 35 |
-
|
| 36 |
-
Args:
|
| 37 |
-
name (str): name of the model.
|
| 38 |
-
compression_model (CompressionModel): Compression model
|
| 39 |
-
used to map audio to invertible discrete representations.
|
| 40 |
-
lm (LMModel): Language model over discrete representations.
|
| 41 |
-
max_duration (float, optional): maximum duration the model can produce,
|
| 42 |
-
otherwise, inferred from the training params.
|
| 43 |
-
"""
|
| 44 |
-
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = 30):
|
| 45 |
-
self.name = name
|
| 46 |
-
self.compression_model = compression_model
|
| 47 |
-
self.lm = lm
|
| 48 |
-
self.cfg: tp.Optional[omegaconf.DictConfig] = None
|
| 49 |
-
# Just to be safe, let's put everything in eval mode.
|
| 50 |
-
self.compression_model.eval()
|
| 51 |
-
self.lm.eval()
|
| 52 |
-
|
| 53 |
-
if hasattr(lm, 'cfg'):
|
| 54 |
-
cfg = lm.cfg
|
| 55 |
-
assert isinstance(cfg, omegaconf.DictConfig)
|
| 56 |
-
self.cfg = cfg
|
| 57 |
-
|
| 58 |
-
if self.cfg is not None:
|
| 59 |
-
self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
|
| 60 |
-
|
| 61 |
-
if max_duration is None:
|
| 62 |
-
if self.cfg is not None:
|
| 63 |
-
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
| 64 |
-
else:
|
| 65 |
-
raise ValueError("You must provide max_duration when building directly MusicGen")
|
| 66 |
-
assert max_duration is not None
|
| 67 |
-
self.max_duration = max_duration
|
| 68 |
-
self.duration = 15.0 # default duration
|
| 69 |
-
self.device = next(iter(lm.parameters())).device
|
| 70 |
-
self.generation_params: dict = {}
|
| 71 |
-
self.set_generation_params(duration=self.duration) # 15 seconds by default
|
| 72 |
-
self._progress_callback: tp.Union[tp.Callable[[int, int], None], gr.Progress] = None
|
| 73 |
-
if self.device.type == 'cpu':
|
| 74 |
-
self.autocast = TorchAutocast(enabled=False)
|
| 75 |
-
else:
|
| 76 |
-
self.autocast = TorchAutocast(
|
| 77 |
-
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
| 78 |
-
|
| 79 |
-
@property
|
| 80 |
-
def version(self) -> str:
|
| 81 |
-
from audiocraft import __version__ as audiocraft_version
|
| 82 |
-
return audiocraft_version
|
| 83 |
-
|
| 84 |
-
@property
|
| 85 |
-
def frame_rate(self) -> float:
|
| 86 |
-
"""Roughly the number of AR steps per seconds."""
|
| 87 |
-
return self.compression_model.frame_rate
|
| 88 |
-
|
| 89 |
-
@property
|
| 90 |
-
def sample_rate(self) -> int:
|
| 91 |
-
"""Sample rate of the generated audio."""
|
| 92 |
-
return self.compression_model.sample_rate
|
| 93 |
-
|
| 94 |
-
@property
|
| 95 |
-
def audio_channels(self) -> int:
|
| 96 |
-
"""Audio channels of the generated audio."""
|
| 97 |
-
return self.compression_model.channels
|
| 98 |
-
|
| 99 |
-
@staticmethod
|
| 100 |
-
def get_pretrained(name: str = 'melody-large', device=None):
|
| 101 |
-
"""Return pretrained model, we provide ten models:
|
| 102 |
-
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
| 103 |
-
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
| 104 |
-
- melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
|
| 105 |
-
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
| 106 |
-
- melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-melody-large
|
| 107 |
-
- stereo-small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
| 108 |
-
- stereo-medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-medium
|
| 109 |
-
- stereo-melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-stereo-melody
|
| 110 |
-
- stereo-large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-stereo-large
|
| 111 |
-
- stereo-melody-large (3.3B), text to music, and text+melody to music # see: https://huggingface.co/facebook/musicgen-stereo-melody-large
|
| 112 |
-
- musicgen-style (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-style
|
| 113 |
-
"""
|
| 114 |
-
|
| 115 |
-
if device is None:
|
| 116 |
-
if torch.cuda.device_count():
|
| 117 |
-
device = 'cuda'
|
| 118 |
-
else:
|
| 119 |
-
device = 'cpu'
|
| 120 |
-
|
| 121 |
-
if name == 'debug':
|
| 122 |
-
# used only for unit tests
|
| 123 |
-
compression_model = get_debug_compression_model(device)
|
| 124 |
-
lm = get_debug_lm_model(device)
|
| 125 |
-
return MusicGen(name, compression_model, lm, max_duration=30)
|
| 126 |
-
|
| 127 |
-
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
| 128 |
-
if not os.path.isfile(name) and not os.path.isdir(name):
|
| 129 |
-
raise ValueError(
|
| 130 |
-
f"{name} is not a valid checkpoint name. "
|
| 131 |
-
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
| 132 |
-
)
|
| 133 |
-
else:
|
| 134 |
-
name = HF_MODEL_CHECKPOINTS_MAP[name]
|
| 135 |
-
|
| 136 |
-
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
| 137 |
-
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
| 138 |
-
lm = load_lm_model(name, device=device, cache_dir=cache_dir)
|
| 139 |
-
if name.__contains__('melody') or 'self_wav' in lm.condition_provider.conditioners:
|
| 140 |
-
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
| 141 |
-
lm.condition_provider.conditioners['self_wav']._use_masking = False
|
| 142 |
-
|
| 143 |
-
return MusicGen(name, compression_model, lm)
|
| 144 |
-
|
| 145 |
-
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
| 146 |
-
top_p: float = 0.0, temperature: float = 1.0,
|
| 147 |
-
duration: float = 30.0, cfg_coef: float = 3.0,
|
| 148 |
-
cfg_coef_beta: tp.Optional[float] = None,
|
| 149 |
-
two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
|
| 150 |
-
"""Set the generation parameters for MusicGen.
|
| 151 |
-
|
| 152 |
-
Args:
|
| 153 |
-
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
| 154 |
-
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
| 155 |
-
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
| 156 |
-
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
| 157 |
-
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
| 158 |
-
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
| 159 |
-
cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance.
|
| 160 |
-
Should be only used for MusicGen melody if we want to push the text condition more than
|
| 161 |
-
the audio conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
|
| 162 |
-
double CFG.
|
| 163 |
-
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
| 164 |
-
instead of batching together the two. This has some impact on how things
|
| 165 |
-
are padded but seems to have little impact in practice.
|
| 166 |
-
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
| 167 |
-
should we extend the audio each time. Larger values will mean less context is
|
| 168 |
-
preserved, and shorter value will require extra computations.
|
| 169 |
-
rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
|
| 170 |
-
"""
|
| 171 |
-
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
| 172 |
-
self.extend_stride = extend_stride
|
| 173 |
-
self.duration = duration
|
| 174 |
-
self.generation_params = {
|
| 175 |
-
#'max_gen_len': int(duration * self.frame_rate),
|
| 176 |
-
'use_sampling': use_sampling,
|
| 177 |
-
'temp': temperature,
|
| 178 |
-
'top_k': top_k,
|
| 179 |
-
'top_p': top_p,
|
| 180 |
-
'cfg_coef': cfg_coef,
|
| 181 |
-
'two_step_cfg': two_step_cfg,
|
| 182 |
-
'cfg_coef_beta': cfg_coef_beta,
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0,
|
| 186 |
-
ds_factor: tp.Optional[int] = None,
|
| 187 |
-
encodec_n_q: tp.Optional[int] = None) -> None:
|
| 188 |
-
"""Set the parameters of the style conditioner
|
| 189 |
-
Args:
|
| 190 |
-
eval_q (int): the number of residual quantization streams used to quantize the style condition
|
| 191 |
-
the smaller it is, the narrower is the information bottleneck
|
| 192 |
-
excerpt_length (float): the excerpt length in seconds that is extracted from the audio
|
| 193 |
-
conditioning
|
| 194 |
-
ds_factor: (int): the downsampling factor used to downsample the style tokens before
|
| 195 |
-
using them as a prefix
|
| 196 |
-
encodec_n_q: (int, optional): if encodec is used as a feature extractor, sets the number
|
| 197 |
-
of streams that is used to extract features
|
| 198 |
-
"""
|
| 199 |
-
assert isinstance(self.lm.condition_provider.conditioners.self_wav, StyleConditioner), \
|
| 200 |
-
"Only use this function if you model is MusicGen-Style"
|
| 201 |
-
self.lm.condition_provider.conditioners.self_wav.set_params(eval_q=eval_q,
|
| 202 |
-
excerpt_length=excerpt_length,
|
| 203 |
-
ds_factor=ds_factor,
|
| 204 |
-
encodec_n_q=encodec_n_q)
|
| 205 |
-
|
| 206 |
-
def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
|
| 207 |
-
"""Override the default progress callback."""
|
| 208 |
-
self._progress_callback = progress_callback
|
| 209 |
-
|
| 210 |
-
def generate_unconditional(self, num_samples: int, progress: bool = False,
|
| 211 |
-
return_tokens: bool = False, progress_callback: gr.Progress = None) -> tp.Union[torch.Tensor,
|
| 212 |
-
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 213 |
-
"""Generate samples in an unconditional manner.
|
| 214 |
-
|
| 215 |
-
Args:
|
| 216 |
-
num_samples (int): Number of samples to be generated.
|
| 217 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 218 |
-
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
| 219 |
-
"""
|
| 220 |
-
descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
|
| 221 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
| 222 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 223 |
-
if return_tokens:
|
| 224 |
-
return self.generate_audio(tokens), tokens
|
| 225 |
-
return self.generate_audio(tokens)
|
| 226 |
-
|
| 227 |
-
def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False, progress_callback: gr.Progress = None) \
|
| 228 |
-
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 229 |
-
"""Generate samples conditioned on text.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
descriptions (list of str): A list of strings used as text conditioning.
|
| 233 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 234 |
-
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
| 235 |
-
"""
|
| 236 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
| 237 |
-
assert prompt_tokens is None
|
| 238 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 239 |
-
if return_tokens:
|
| 240 |
-
return self.generate_audio(tokens), tokens
|
| 241 |
-
return self.generate_audio(tokens)
|
| 242 |
-
|
| 243 |
-
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
| 244 |
-
melody_sample_rate: int, progress: bool = False,
|
| 245 |
-
return_tokens: bool = False, progress_callback=gr.Progress(track_tqdm=True)) -> tp.Union[torch.Tensor,
|
| 246 |
-
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 247 |
-
"""Generate samples conditioned on text and melody.
|
| 248 |
-
|
| 249 |
-
Args:
|
| 250 |
-
descriptions (list of str): A list of strings used as text conditioning.
|
| 251 |
-
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
| 252 |
-
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
| 253 |
-
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
| 254 |
-
a list of [C, T] tensors.
|
| 255 |
-
melody_sample_rate: (int): Sample rate of the melody waveforms.
|
| 256 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 257 |
-
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
| 258 |
-
"""
|
| 259 |
-
if isinstance(melody_wavs, torch.Tensor):
|
| 260 |
-
if melody_wavs.dim() == 2:
|
| 261 |
-
melody_wavs = melody_wavs[None]
|
| 262 |
-
if melody_wavs.dim() != 3:
|
| 263 |
-
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
| 264 |
-
melody_wavs = list(melody_wavs)
|
| 265 |
-
else:
|
| 266 |
-
for melody in melody_wavs:
|
| 267 |
-
if melody is not None:
|
| 268 |
-
assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
|
| 269 |
-
|
| 270 |
-
melody_wavs = [
|
| 271 |
-
convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
|
| 272 |
-
if wav is not None else None
|
| 273 |
-
for wav in melody_wavs]
|
| 274 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
| 275 |
-
melody_wavs=melody_wavs)
|
| 276 |
-
assert prompt_tokens is None
|
| 277 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 278 |
-
if return_tokens:
|
| 279 |
-
return self.generate_audio(tokens), tokens
|
| 280 |
-
return self.generate_audio(tokens)
|
| 281 |
-
|
| 282 |
-
def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
| 283 |
-
sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False, progress_callback: gr.Progress = None) \
|
| 284 |
-
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 285 |
-
"""Generate samples conditioned on text and melody and audio prompts.
|
| 286 |
-
Args:
|
| 287 |
-
descriptions (tp.List[str]): A list of strings used as text conditioning.
|
| 288 |
-
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
| 289 |
-
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
| 290 |
-
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
| 291 |
-
a list of [C, T] tensors.
|
| 292 |
-
sample_rate: (int): Sample rate of the melody waveforms.
|
| 293 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 294 |
-
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 295 |
-
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
| 296 |
-
"""
|
| 297 |
-
if isinstance(melody_wavs, torch.Tensor):
|
| 298 |
-
if melody_wavs.dim() == 2:
|
| 299 |
-
melody_wavs = melody_wavs[None]
|
| 300 |
-
if melody_wavs.dim() != 3:
|
| 301 |
-
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
| 302 |
-
melody_wavs = list(melody_wavs)
|
| 303 |
-
else:
|
| 304 |
-
for melody in melody_wavs:
|
| 305 |
-
if melody is not None:
|
| 306 |
-
assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
|
| 307 |
-
|
| 308 |
-
melody_wavs = [
|
| 309 |
-
convert_audio(wav, sample_rate, self.sample_rate, self.audio_channels)
|
| 310 |
-
if wav is not None else None
|
| 311 |
-
for wav in melody_wavs]
|
| 312 |
-
#attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
| 313 |
-
# melody_wavs=melody_wavs)
|
| 314 |
-
|
| 315 |
-
if prompt is not None:
|
| 316 |
-
if prompt.dim() == 2:
|
| 317 |
-
prompt = prompt[None]
|
| 318 |
-
if prompt.dim() != 3:
|
| 319 |
-
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
| 320 |
-
prompt = convert_audio(prompt, sample_rate, self.sample_rate, self.audio_channels)
|
| 321 |
-
if descriptions is None:
|
| 322 |
-
descriptions = [None] * len(prompt)
|
| 323 |
-
|
| 324 |
-
#if prompt is not None:
|
| 325 |
-
# attributes_gen, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
| 326 |
-
|
| 327 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=prompt,
|
| 328 |
-
melody_wavs=melody_wavs)
|
| 329 |
-
if prompt is not None:
|
| 330 |
-
assert prompt_tokens is not None
|
| 331 |
-
else:
|
| 332 |
-
assert prompt_tokens is None
|
| 333 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 334 |
-
if return_tokens:
|
| 335 |
-
return self.generate_audio(tokens), tokens
|
| 336 |
-
return self.generate_audio(tokens)
|
| 337 |
-
|
| 338 |
-
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
| 339 |
-
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
| 340 |
-
progress: bool = False, return_tokens: bool = False, progress_callback: gr.Progress = None) \
|
| 341 |
-
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 342 |
-
"""Generate samples conditioned on audio prompts.
|
| 343 |
-
|
| 344 |
-
Args:
|
| 345 |
-
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 346 |
-
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
| 347 |
-
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
| 348 |
-
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
| 349 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 350 |
-
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.\
|
| 351 |
-
This is truly a hack and does not follow the progression of conditioning melody or previously generated audio.
|
| 352 |
-
"""
|
| 353 |
-
if prompt.dim() == 2:
|
| 354 |
-
prompt = prompt[None]
|
| 355 |
-
if prompt.dim() != 3:
|
| 356 |
-
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
| 357 |
-
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
| 358 |
-
if descriptions is None:
|
| 359 |
-
descriptions = [None] * len(prompt)
|
| 360 |
-
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
| 361 |
-
assert prompt_tokens is not None
|
| 362 |
-
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
| 363 |
-
if return_tokens:
|
| 364 |
-
return self.generate_audio(tokens), tokens
|
| 365 |
-
return self.generate_audio(tokens)
|
| 366 |
-
|
| 367 |
-
@torch.no_grad()
|
| 368 |
-
def _prepare_tokens_and_attributes(
|
| 369 |
-
self,
|
| 370 |
-
descriptions: tp.Sequence[tp.Optional[str]],
|
| 371 |
-
prompt: tp.Optional[torch.Tensor],
|
| 372 |
-
melody_wavs: tp.Optional[MelodyList] = None,
|
| 373 |
-
progress_callback: tp.Optional[gr.Progress] = None
|
| 374 |
-
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
|
| 375 |
-
"""Prepare model inputs.
|
| 376 |
-
|
| 377 |
-
Args:
|
| 378 |
-
descriptions (list of str): A list of strings used as text conditioning.
|
| 379 |
-
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 380 |
-
melody_wavs (torch.Tensor, optional): A batch of waveforms
|
| 381 |
-
used as melody conditioning. Defaults to None.
|
| 382 |
-
"""
|
| 383 |
-
attributes = [
|
| 384 |
-
ConditioningAttributes(text={'description': description})
|
| 385 |
-
for description in descriptions]
|
| 386 |
-
|
| 387 |
-
if melody_wavs is None:
|
| 388 |
-
for attr in attributes:
|
| 389 |
-
attr.wav['self_wav'] = WavCondition(
|
| 390 |
-
torch.zeros((1, 1, 1), device=self.device),
|
| 391 |
-
torch.tensor([0], device=self.device),
|
| 392 |
-
sample_rate=[self.sample_rate],
|
| 393 |
-
path=[None]) # type: ignore
|
| 394 |
-
else:
|
| 395 |
-
if 'self_wav' not in self.lm.condition_provider.conditioners:
|
| 396 |
-
raise RuntimeError("This model doesn't support melody conditioning. "
|
| 397 |
-
"Use the `melody` model.")
|
| 398 |
-
assert len(melody_wavs) == len(descriptions), \
|
| 399 |
-
f"number of melody wavs must match number of descriptions! " \
|
| 400 |
-
f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
|
| 401 |
-
for attr, melody in zip(attributes, melody_wavs):
|
| 402 |
-
if melody is None:
|
| 403 |
-
attr.wav['self_wav'] = WavCondition(
|
| 404 |
-
torch.zeros((1, 1, 1), device=self.device),
|
| 405 |
-
torch.tensor([0], device=self.device),
|
| 406 |
-
sample_rate=[self.sample_rate],
|
| 407 |
-
path=[None]) # type: ignore
|
| 408 |
-
else:
|
| 409 |
-
attr.wav['self_wav'] = WavCondition(
|
| 410 |
-
melody[None].to(device=self.device),
|
| 411 |
-
torch.tensor([melody.shape[-1]], device=self.device),
|
| 412 |
-
sample_rate=[self.sample_rate],
|
| 413 |
-
path=[None],
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
if prompt is not None:
|
| 417 |
-
if descriptions is not None:
|
| 418 |
-
assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
|
| 419 |
-
prompt = prompt.to(self.device)
|
| 420 |
-
prompt_tokens, scale = self.compression_model.encode(prompt)
|
| 421 |
-
assert scale is None
|
| 422 |
-
else:
|
| 423 |
-
prompt_tokens = None
|
| 424 |
-
return attributes, prompt_tokens
|
| 425 |
-
|
| 426 |
-
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
| 427 |
-
prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
|
| 428 |
-
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
| 429 |
-
|
| 430 |
-
Args:
|
| 431 |
-
attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
|
| 432 |
-
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
| 433 |
-
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 434 |
-
Returns:
|
| 435 |
-
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 436 |
-
"""
|
| 437 |
-
total_gen_len = int(self.duration * self.frame_rate)
|
| 438 |
-
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
| 439 |
-
current_gen_offset: int = 0
|
| 440 |
-
|
| 441 |
-
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 442 |
-
generated_tokens += current_gen_offset
|
| 443 |
-
generated_tokens /= ((tokens_to_generate) / self.duration)
|
| 444 |
-
tokens_to_generate /= ((tokens_to_generate) / self.duration)
|
| 445 |
-
if self._progress_callback is not None:
|
| 446 |
-
# Note that total_gen_len might be quite wrong depending on the
|
| 447 |
-
# codebook pattern used, but with delay it is almost accurate.
|
| 448 |
-
self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
|
| 449 |
-
if progress_callback is not None:
|
| 450 |
-
# Update Gradio progress bar
|
| 451 |
-
progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
|
| 452 |
-
if progress:
|
| 453 |
-
print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
|
| 454 |
-
|
| 455 |
-
if prompt_tokens is not None:
|
| 456 |
-
if prompt_tokens.shape[-1] > max_prompt_len:
|
| 457 |
-
prompt_tokens = prompt_tokens[..., :max_prompt_len]
|
| 458 |
-
|
| 459 |
-
# callback = None
|
| 460 |
-
callback = _progress_callback
|
| 461 |
-
|
| 462 |
-
if self.duration <= self.max_duration:
|
| 463 |
-
# generate by sampling from LM, simple case.
|
| 464 |
-
with self.autocast:
|
| 465 |
-
gen_tokens = self.lm.generate(
|
| 466 |
-
prompt_tokens, attributes,
|
| 467 |
-
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
| 468 |
-
|
| 469 |
-
else:
|
| 470 |
-
# now this gets a bit messier, we need to handle prompts,
|
| 471 |
-
# melody conditioning etc.
|
| 472 |
-
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
| 473 |
-
all_tokens = []
|
| 474 |
-
if prompt_tokens is None:
|
| 475 |
-
prompt_length = 0
|
| 476 |
-
else:
|
| 477 |
-
all_tokens.append(prompt_tokens)
|
| 478 |
-
prompt_length = prompt_tokens.shape[-1]
|
| 479 |
-
|
| 480 |
-
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 481 |
-
|
| 482 |
-
while current_gen_offset + prompt_length < total_gen_len:
|
| 483 |
-
time_offset = current_gen_offset / self.frame_rate
|
| 484 |
-
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 485 |
-
max_gen_len = int(chunk_duration * self.frame_rate)
|
| 486 |
-
for attr, ref_wav in zip(attributes, ref_wavs):
|
| 487 |
-
wav_length = ref_wav.length.item()
|
| 488 |
-
if wav_length == 0:
|
| 489 |
-
continue
|
| 490 |
-
# We will extend the wav periodically if it not long enough.
|
| 491 |
-
# we have to do it here rather than in conditioners.py as otherwise
|
| 492 |
-
# we wouldn't have the full wav.
|
| 493 |
-
initial_position = int(time_offset * self.sample_rate)
|
| 494 |
-
wav_target_length = int(self.max_duration * self.sample_rate)
|
| 495 |
-
print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
|
| 496 |
-
positions = torch.arange(initial_position,
|
| 497 |
-
initial_position + wav_target_length, device=self.device)
|
| 498 |
-
attr.wav['self_wav'] = WavCondition(
|
| 499 |
-
ref_wav[0][..., positions % wav_length],
|
| 500 |
-
torch.full_like(ref_wav[1], wav_target_length),
|
| 501 |
-
[self.sample_rate] * ref_wav[0].size(0),
|
| 502 |
-
[None], [0.])
|
| 503 |
-
with self.autocast:
|
| 504 |
-
gen_tokens = self.lm.generate(
|
| 505 |
-
prompt_tokens, attributes,
|
| 506 |
-
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
| 507 |
-
if prompt_tokens is None:
|
| 508 |
-
all_tokens.append(gen_tokens)
|
| 509 |
-
else:
|
| 510 |
-
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
| 511 |
-
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
| 512 |
-
prompt_length = prompt_tokens.shape[-1]
|
| 513 |
-
current_gen_offset += stride_tokens
|
| 514 |
-
|
| 515 |
-
gen_tokens = torch.cat(all_tokens, dim=-1)
|
| 516 |
-
return gen_tokens
|
| 517 |
-
|
| 518 |
-
# generate audio
|
| 519 |
-
|
| 520 |
-
def generate_audio(self, gen_tokens: torch.Tensor):
|
| 521 |
-
try:
|
| 522 |
-
"""Generate Audio from tokens"""
|
| 523 |
-
assert gen_tokens.dim() == 3
|
| 524 |
-
with torch.no_grad():
|
| 525 |
-
gen_audio = self.compression_model.decode(gen_tokens, None)
|
| 526 |
-
return gen_audio
|
| 527 |
-
except Exception as e:
|
| 528 |
-
print(f"Error generating audio: {e}")
|
| 529 |
-
return None
|
| 530 |
-
|
| 531 |
-
#def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
| 532 |
-
# prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
| 533 |
-
# """Generate discrete audio tokens given audio prompt and/or conditions.
|
| 534 |
-
|
| 535 |
-
# Args:
|
| 536 |
-
# attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
|
| 537 |
-
# prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
|
| 538 |
-
# progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
| 539 |
-
# Returns:
|
| 540 |
-
# torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
| 541 |
-
# """
|
| 542 |
-
# def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 543 |
-
# print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
|
| 544 |
-
|
| 545 |
-
# if prompt_tokens is not None:
|
| 546 |
-
# assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
| 547 |
-
# "Prompt is longer than audio to generate"
|
| 548 |
-
|
| 549 |
-
# callback = None
|
| 550 |
-
# if progress:
|
| 551 |
-
# callback = _progress_callback
|
| 552 |
-
|
| 553 |
-
# # generate by sampling from LM
|
| 554 |
-
# with self.autocast:
|
| 555 |
-
# gen_tokens = self.lm.generate(prompt_tokens, attributes, callback=callback, **self.generation_params)
|
| 556 |
-
|
| 557 |
-
# # generate audio
|
| 558 |
-
# assert gen_tokens.dim() == 3
|
| 559 |
-
# with torch.no_grad():
|
| 560 |
-
# gen_audio = self.compression_model.decode(gen_tokens, None)
|
| 561 |
-
# return gen_audio
|
| 562 |
-
|
| 563 |
-
def to(self, device: str):
|
| 564 |
-
self.compression_model.to(device)
|
| 565 |
-
self.lm.to(device)
|
| 566 |
-
return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/models/unet.py
DELETED
|
@@ -1,214 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Pytorch Unet Module used for diffusion.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
import typing as tp
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
from torch import nn
|
| 16 |
-
from torch.nn import functional as F
|
| 17 |
-
from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
@dataclass
|
| 21 |
-
class Output:
|
| 22 |
-
sample: torch.Tensor
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def get_model(cfg, channels: int, side: int, num_steps: int):
|
| 26 |
-
if cfg.model == 'unet':
|
| 27 |
-
return DiffusionUnet(
|
| 28 |
-
chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
|
| 29 |
-
else:
|
| 30 |
-
raise RuntimeError('Not Implemented')
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class ResBlock(nn.Module):
|
| 34 |
-
def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
|
| 35 |
-
dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
| 36 |
-
dropout: float = 0.):
|
| 37 |
-
super().__init__()
|
| 38 |
-
stride = 1
|
| 39 |
-
padding = dilation * (kernel - stride) // 2
|
| 40 |
-
Conv = nn.Conv1d
|
| 41 |
-
Drop = nn.Dropout1d
|
| 42 |
-
self.norm1 = nn.GroupNorm(norm_groups, channels)
|
| 43 |
-
self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
|
| 44 |
-
self.activation1 = activation()
|
| 45 |
-
self.dropout1 = Drop(dropout)
|
| 46 |
-
|
| 47 |
-
self.norm2 = nn.GroupNorm(norm_groups, channels)
|
| 48 |
-
self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
|
| 49 |
-
self.activation2 = activation()
|
| 50 |
-
self.dropout2 = Drop(dropout)
|
| 51 |
-
|
| 52 |
-
def forward(self, x):
|
| 53 |
-
h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
|
| 54 |
-
h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
|
| 55 |
-
return x + h
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class DecoderLayer(nn.Module):
|
| 59 |
-
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
|
| 60 |
-
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
| 61 |
-
dropout: float = 0.):
|
| 62 |
-
super().__init__()
|
| 63 |
-
padding = (kernel - stride) // 2
|
| 64 |
-
self.res_blocks = nn.Sequential(
|
| 65 |
-
*[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
|
| 66 |
-
for idx in range(res_blocks)])
|
| 67 |
-
self.norm = nn.GroupNorm(norm_groups, chin)
|
| 68 |
-
ConvTr = nn.ConvTranspose1d
|
| 69 |
-
self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
|
| 70 |
-
self.activation = activation()
|
| 71 |
-
|
| 72 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
-
x = self.res_blocks(x)
|
| 74 |
-
x = self.norm(x)
|
| 75 |
-
x = self.activation(x)
|
| 76 |
-
x = self.convtr(x)
|
| 77 |
-
return x
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class EncoderLayer(nn.Module):
|
| 81 |
-
def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
|
| 82 |
-
norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
|
| 83 |
-
dropout: float = 0.):
|
| 84 |
-
super().__init__()
|
| 85 |
-
padding = (kernel - stride) // 2
|
| 86 |
-
Conv = nn.Conv1d
|
| 87 |
-
self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
|
| 88 |
-
self.norm = nn.GroupNorm(norm_groups, chout)
|
| 89 |
-
self.activation = activation()
|
| 90 |
-
self.res_blocks = nn.Sequential(
|
| 91 |
-
*[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
|
| 92 |
-
for idx in range(res_blocks)])
|
| 93 |
-
|
| 94 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 95 |
-
B, C, T = x.shape
|
| 96 |
-
stride, = self.conv.stride
|
| 97 |
-
pad = (stride - (T % stride)) % stride
|
| 98 |
-
x = F.pad(x, (0, pad))
|
| 99 |
-
|
| 100 |
-
x = self.conv(x)
|
| 101 |
-
x = self.norm(x)
|
| 102 |
-
x = self.activation(x)
|
| 103 |
-
x = self.res_blocks(x)
|
| 104 |
-
return x
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
class BLSTM(nn.Module):
|
| 108 |
-
"""BiLSTM with same hidden units as input dim.
|
| 109 |
-
"""
|
| 110 |
-
def __init__(self, dim, layers=2):
|
| 111 |
-
super().__init__()
|
| 112 |
-
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
| 113 |
-
self.linear = nn.Linear(2 * dim, dim)
|
| 114 |
-
|
| 115 |
-
def forward(self, x):
|
| 116 |
-
x = x.permute(2, 0, 1)
|
| 117 |
-
x = self.lstm(x)[0]
|
| 118 |
-
x = self.linear(x)
|
| 119 |
-
x = x.permute(1, 2, 0)
|
| 120 |
-
return x
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
class DiffusionUnet(nn.Module):
|
| 124 |
-
def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
|
| 125 |
-
max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
|
| 126 |
-
bilstm: bool = False, transformer: bool = False,
|
| 127 |
-
codec_dim: tp.Optional[int] = None, **kwargs):
|
| 128 |
-
super().__init__()
|
| 129 |
-
self.encoders = nn.ModuleList()
|
| 130 |
-
self.decoders = nn.ModuleList()
|
| 131 |
-
self.embeddings: tp.Optional[nn.ModuleList] = None
|
| 132 |
-
self.embedding = nn.Embedding(num_steps, hidden)
|
| 133 |
-
if emb_all_layers:
|
| 134 |
-
self.embeddings = nn.ModuleList()
|
| 135 |
-
self.condition_embedding: tp.Optional[nn.Module] = None
|
| 136 |
-
for d in range(depth):
|
| 137 |
-
encoder = EncoderLayer(chin, hidden, **kwargs)
|
| 138 |
-
decoder = DecoderLayer(hidden, chin, **kwargs)
|
| 139 |
-
self.encoders.append(encoder)
|
| 140 |
-
self.decoders.insert(0, decoder)
|
| 141 |
-
if emb_all_layers and d > 0:
|
| 142 |
-
assert self.embeddings is not None
|
| 143 |
-
self.embeddings.append(nn.Embedding(num_steps, hidden))
|
| 144 |
-
chin = hidden
|
| 145 |
-
hidden = min(int(chin * growth), max_channels)
|
| 146 |
-
self.bilstm: tp.Optional[nn.Module]
|
| 147 |
-
if bilstm:
|
| 148 |
-
self.bilstm = BLSTM(chin)
|
| 149 |
-
else:
|
| 150 |
-
self.bilstm = None
|
| 151 |
-
self.use_transformer = transformer
|
| 152 |
-
self.cross_attention = False
|
| 153 |
-
if transformer:
|
| 154 |
-
self.cross_attention = cross_attention
|
| 155 |
-
self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
|
| 156 |
-
cross_attention=cross_attention)
|
| 157 |
-
|
| 158 |
-
self.use_codec = False
|
| 159 |
-
if codec_dim is not None:
|
| 160 |
-
self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
|
| 161 |
-
self.use_codec = True
|
| 162 |
-
|
| 163 |
-
def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
|
| 164 |
-
skips = []
|
| 165 |
-
bs = x.size(0)
|
| 166 |
-
z = x
|
| 167 |
-
view_args = [1]
|
| 168 |
-
if type(step) is torch.Tensor:
|
| 169 |
-
step_tensor = step
|
| 170 |
-
else:
|
| 171 |
-
step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
|
| 172 |
-
|
| 173 |
-
for idx, encoder in enumerate(self.encoders):
|
| 174 |
-
z = encoder(z)
|
| 175 |
-
if idx == 0:
|
| 176 |
-
z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
|
| 177 |
-
elif self.embeddings is not None:
|
| 178 |
-
z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
|
| 179 |
-
|
| 180 |
-
skips.append(z)
|
| 181 |
-
|
| 182 |
-
if self.use_codec: # insert condition in the bottleneck
|
| 183 |
-
assert condition is not None, "Model defined for conditionnal generation"
|
| 184 |
-
condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
|
| 185 |
-
assert condition_emb.size(-1) <= 2 * z.size(-1), \
|
| 186 |
-
f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
|
| 187 |
-
if not self.cross_attention:
|
| 188 |
-
|
| 189 |
-
condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
|
| 190 |
-
assert z.size() == condition_emb.size()
|
| 191 |
-
z += condition_emb
|
| 192 |
-
cross_attention_src = None
|
| 193 |
-
else:
|
| 194 |
-
cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
|
| 195 |
-
B, T, C = cross_attention_src.shape
|
| 196 |
-
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
| 197 |
-
pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
|
| 198 |
-
cross_attention_src = cross_attention_src + pos_emb
|
| 199 |
-
if self.use_transformer:
|
| 200 |
-
z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
|
| 201 |
-
else:
|
| 202 |
-
if self.bilstm is None:
|
| 203 |
-
z = torch.zeros_like(z)
|
| 204 |
-
else:
|
| 205 |
-
z = self.bilstm(z)
|
| 206 |
-
|
| 207 |
-
for decoder in self.decoders:
|
| 208 |
-
s = skips.pop(-1)
|
| 209 |
-
z = z[:, :, :s.shape[2]]
|
| 210 |
-
z = z + s
|
| 211 |
-
z = decoder(z)
|
| 212 |
-
|
| 213 |
-
z = z[:, :, :x.shape[2]]
|
| 214 |
-
return Output(z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/__init__.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
# flake8: noqa
|
| 8 |
-
from .conv import (
|
| 9 |
-
NormConv1d,
|
| 10 |
-
NormConv2d,
|
| 11 |
-
NormConvTranspose1d,
|
| 12 |
-
NormConvTranspose2d,
|
| 13 |
-
StreamableConv1d,
|
| 14 |
-
StreamableConvTranspose1d,
|
| 15 |
-
pad_for_conv1d,
|
| 16 |
-
pad1d,
|
| 17 |
-
unpad1d,
|
| 18 |
-
)
|
| 19 |
-
from .lstm import StreamableLSTM
|
| 20 |
-
from .seanet import SEANetEncoder, SEANetDecoder
|
| 21 |
-
from .transformer import StreamingTransformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/activations.py
DELETED
|
@@ -1,96 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
from torch import Tensor
|
| 10 |
-
from typing import Union, Callable
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class CustomGLU(nn.Module):
|
| 14 |
-
"""Custom Gated Linear Unit activation.
|
| 15 |
-
Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
|
| 16 |
-
of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
|
| 17 |
-
function (i.e. sigmoid, swish, etc.).
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
activation (nn.Module): The custom activation to apply in the Gated Linear Unit
|
| 21 |
-
dim (int): the dimension on which to split the input. Default: -1
|
| 22 |
-
|
| 23 |
-
Shape:
|
| 24 |
-
- Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
|
| 25 |
-
dimensions
|
| 26 |
-
- Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
|
| 27 |
-
|
| 28 |
-
Examples::
|
| 29 |
-
>>> m = CustomGLU(nn.Sigmoid())
|
| 30 |
-
>>> input = torch.randn(4, 2)
|
| 31 |
-
>>> output = m(input)
|
| 32 |
-
"""
|
| 33 |
-
def __init__(self, activation: nn.Module, dim: int = -1):
|
| 34 |
-
super(CustomGLU, self).__init__()
|
| 35 |
-
self.dim = dim
|
| 36 |
-
self.activation = activation
|
| 37 |
-
|
| 38 |
-
def forward(self, x: Tensor):
|
| 39 |
-
assert x.shape[self.dim] % 2 == 0 # M = N / 2
|
| 40 |
-
a, b = torch.chunk(x, 2, dim=self.dim)
|
| 41 |
-
return a * self.activation(b)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class SwiGLU(CustomGLU):
|
| 45 |
-
"""SiLU Gated Linear Unit activation.
|
| 46 |
-
Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
|
| 47 |
-
the first half of the input matrices, :math:`b` is the second half.
|
| 48 |
-
|
| 49 |
-
Args:
|
| 50 |
-
dim (int): the dimension on which to split the input. Default: -1
|
| 51 |
-
"""
|
| 52 |
-
def __init__(self, dim: int = -1):
|
| 53 |
-
super(SwiGLU, self).__init__(nn.SiLU(), dim)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class GeGLU(CustomGLU):
|
| 57 |
-
"""GeLU Gated Linear Unit activation.
|
| 58 |
-
Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
|
| 59 |
-
the first half of the input matrices, :math:`b` is the second half.
|
| 60 |
-
|
| 61 |
-
Args:
|
| 62 |
-
dim (int): the dimension on which to split the input. Default: -1
|
| 63 |
-
"""
|
| 64 |
-
def __init__(self, dim: int = -1):
|
| 65 |
-
super(GeGLU, self).__init__(nn.GELU(), dim)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class ReGLU(CustomGLU):
|
| 69 |
-
"""ReLU Gated Linear Unit activation.
|
| 70 |
-
Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
|
| 71 |
-
the first half of the input matrices, :math:`b` is the second half.
|
| 72 |
-
|
| 73 |
-
Args:
|
| 74 |
-
dim (int): the dimension on which to split the input. Default: -1
|
| 75 |
-
"""
|
| 76 |
-
def __init__(self, dim: int = -1):
|
| 77 |
-
super(ReGLU, self).__init__(nn.ReLU(), dim)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def get_activation_fn(
|
| 81 |
-
activation: Union[str, Callable[[Tensor], Tensor]]
|
| 82 |
-
) -> Union[str, Callable[[Tensor], Tensor]]:
|
| 83 |
-
"""Helper function to map an activation string to the activation class.
|
| 84 |
-
If the supplied activation is not a string that is recognized, the activation is passed back.
|
| 85 |
-
|
| 86 |
-
Args:
|
| 87 |
-
activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
|
| 88 |
-
"""
|
| 89 |
-
if isinstance(activation, str):
|
| 90 |
-
if activation == "reglu":
|
| 91 |
-
return ReGLU()
|
| 92 |
-
elif activation == "geglu":
|
| 93 |
-
return GeGLU()
|
| 94 |
-
elif activation == "swiglu":
|
| 95 |
-
return SwiGLU()
|
| 96 |
-
return activation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/chroma.py
DELETED
|
@@ -1,66 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
import typing as tp
|
| 7 |
-
|
| 8 |
-
from einops import rearrange
|
| 9 |
-
from librosa import filters
|
| 10 |
-
import torch
|
| 11 |
-
from torch import nn
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
import torchaudio
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class ChromaExtractor(nn.Module):
|
| 17 |
-
"""Chroma extraction and quantization.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
sample_rate (int): Sample rate for the chroma extraction.
|
| 21 |
-
n_chroma (int): Number of chroma bins for the chroma extraction.
|
| 22 |
-
radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
|
| 23 |
-
nfft (int, optional): Number of FFT.
|
| 24 |
-
winlen (int, optional): Window length.
|
| 25 |
-
winhop (int, optional): Window hop size.
|
| 26 |
-
argmax (bool, optional): Whether to use argmax. Defaults to False.
|
| 27 |
-
norm (float, optional): Norm for chroma normalization. Defaults to inf.
|
| 28 |
-
"""
|
| 29 |
-
def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
|
| 30 |
-
winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
|
| 31 |
-
norm: float = torch.inf):
|
| 32 |
-
super().__init__()
|
| 33 |
-
self.winlen = winlen or 2 ** radix2_exp
|
| 34 |
-
self.nfft = nfft or self.winlen
|
| 35 |
-
self.winhop = winhop or (self.winlen // 4)
|
| 36 |
-
self.sample_rate = sample_rate
|
| 37 |
-
self.n_chroma = n_chroma
|
| 38 |
-
self.norm = norm
|
| 39 |
-
self.argmax = argmax
|
| 40 |
-
self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
|
| 41 |
-
n_chroma=self.n_chroma)), persistent=False)
|
| 42 |
-
self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
|
| 43 |
-
hop_length=self.winhop, power=2, center=True,
|
| 44 |
-
pad=0, normalized=True)
|
| 45 |
-
|
| 46 |
-
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
| 47 |
-
T = wav.shape[-1]
|
| 48 |
-
# in case we are getting a wav that was dropped out (nullified)
|
| 49 |
-
# from the conditioner, make sure wav length is no less that nfft
|
| 50 |
-
if T < self.nfft:
|
| 51 |
-
pad = self.nfft - T
|
| 52 |
-
r = 0 if pad % 2 == 0 else 1
|
| 53 |
-
wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
|
| 54 |
-
assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
|
| 55 |
-
|
| 56 |
-
spec = self.spec(wav).squeeze(1)
|
| 57 |
-
raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
|
| 58 |
-
norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
|
| 59 |
-
norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
|
| 60 |
-
|
| 61 |
-
if self.argmax:
|
| 62 |
-
idx = norm_chroma.argmax(-1, keepdim=True)
|
| 63 |
-
norm_chroma[:] = 0
|
| 64 |
-
norm_chroma.scatter_(dim=-1, index=idx, value=1)
|
| 65 |
-
|
| 66 |
-
return norm_chroma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/codebooks_patterns.py
DELETED
|
@@ -1,548 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from collections import namedtuple
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from functools import lru_cache
|
| 10 |
-
import logging
|
| 11 |
-
import typing as tp
|
| 12 |
-
|
| 13 |
-
from abc import ABC, abstractmethod
|
| 14 |
-
import torch
|
| 15 |
-
|
| 16 |
-
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
| 17 |
-
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
| 18 |
-
logger = logging.getLogger(__name__)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@dataclass
|
| 22 |
-
class Pattern:
|
| 23 |
-
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
| 24 |
-
|
| 25 |
-
The codebook pattern consists in a layout, defining for each sequence step
|
| 26 |
-
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
| 27 |
-
The first item of the pattern is always an empty list in order to properly insert a special token
|
| 28 |
-
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
|
| 29 |
-
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
| 30 |
-
|
| 31 |
-
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
| 32 |
-
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
| 33 |
-
to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
|
| 34 |
-
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
| 35 |
-
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
| 36 |
-
is returned along with a mask indicating valid tokens.
|
| 37 |
-
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
| 38 |
-
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
| 39 |
-
to fill and specify invalid positions if needed.
|
| 40 |
-
See the dedicated methods for more details.
|
| 41 |
-
"""
|
| 42 |
-
# Pattern layout, for each sequence step, we have a list of coordinates
|
| 43 |
-
# corresponding to the original codebook timestep and position.
|
| 44 |
-
# The first list is always an empty list in order to properly insert
|
| 45 |
-
# a special token to start with.
|
| 46 |
-
layout: PatternLayout
|
| 47 |
-
timesteps: int
|
| 48 |
-
n_q: int
|
| 49 |
-
|
| 50 |
-
def __post_init__(self):
|
| 51 |
-
assert len(self.layout) > 0
|
| 52 |
-
self._validate_layout()
|
| 53 |
-
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
| 54 |
-
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
| 55 |
-
logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
|
| 56 |
-
|
| 57 |
-
def _validate_layout(self):
|
| 58 |
-
"""Runs checks on the layout to ensure a valid pattern is defined.
|
| 59 |
-
A pattern is considered invalid if:
|
| 60 |
-
- Multiple timesteps for a same codebook are defined in the same sequence step
|
| 61 |
-
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
| 62 |
-
(this would mean that we have future timesteps before past timesteps).
|
| 63 |
-
"""
|
| 64 |
-
q_timesteps = {q: 0 for q in range(self.n_q)}
|
| 65 |
-
for s, seq_coords in enumerate(self.layout):
|
| 66 |
-
if len(seq_coords) > 0:
|
| 67 |
-
qs = set()
|
| 68 |
-
for coord in seq_coords:
|
| 69 |
-
qs.add(coord.q)
|
| 70 |
-
last_q_timestep = q_timesteps[coord.q]
|
| 71 |
-
assert coord.t >= last_q_timestep, \
|
| 72 |
-
f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
| 73 |
-
q_timesteps[coord.q] = coord.t
|
| 74 |
-
# each sequence step contains at max 1 coordinate per codebook
|
| 75 |
-
assert len(qs) == len(seq_coords), \
|
| 76 |
-
f"Multiple entries for a same codebook are found at step {s}"
|
| 77 |
-
|
| 78 |
-
@property
|
| 79 |
-
def num_sequence_steps(self):
|
| 80 |
-
return len(self.layout) - 1
|
| 81 |
-
|
| 82 |
-
@property
|
| 83 |
-
def max_delay(self):
|
| 84 |
-
max_t_in_seq_coords = 0
|
| 85 |
-
for seq_coords in self.layout[1:]:
|
| 86 |
-
for coords in seq_coords:
|
| 87 |
-
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
| 88 |
-
return max_t_in_seq_coords - self.timesteps
|
| 89 |
-
|
| 90 |
-
@property
|
| 91 |
-
def valid_layout(self):
|
| 92 |
-
valid_step = len(self.layout) - self.max_delay
|
| 93 |
-
return self.layout[:valid_step]
|
| 94 |
-
|
| 95 |
-
def starts_with_special_token(self):
|
| 96 |
-
return self.layout[0] == []
|
| 97 |
-
|
| 98 |
-
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
| 99 |
-
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
| 100 |
-
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
| 101 |
-
and the actual codebook coordinates.
|
| 102 |
-
"""
|
| 103 |
-
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
|
| 104 |
-
if q is not None:
|
| 105 |
-
assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
|
| 106 |
-
coords = []
|
| 107 |
-
for s, seq_codes in enumerate(self.layout):
|
| 108 |
-
for code in seq_codes:
|
| 109 |
-
if code.t == t and (q is None or code.q == q):
|
| 110 |
-
coords.append((s, code))
|
| 111 |
-
return coords
|
| 112 |
-
|
| 113 |
-
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
|
| 114 |
-
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
| 115 |
-
|
| 116 |
-
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
|
| 117 |
-
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
| 118 |
-
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
| 119 |
-
|
| 120 |
-
def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
|
| 121 |
-
device: tp.Union[torch.device, str] = 'cpu'):
|
| 122 |
-
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
timesteps (int): Maximum number of timesteps steps to consider.
|
| 126 |
-
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
| 127 |
-
device (torch.device or str): Device for created tensors.
|
| 128 |
-
Returns:
|
| 129 |
-
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
| 130 |
-
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
| 131 |
-
"""
|
| 132 |
-
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
| 133 |
-
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
|
| 134 |
-
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
| 135 |
-
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
| 136 |
-
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
| 137 |
-
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
| 138 |
-
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
|
| 139 |
-
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
|
| 140 |
-
# fill indexes with last sequence step value that will correspond to our special token
|
| 141 |
-
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
|
| 142 |
-
# which will correspond to the index: n_q * timesteps
|
| 143 |
-
indexes[:] = n_q * timesteps
|
| 144 |
-
# iterate over the pattern and fill scattered indexes and mask
|
| 145 |
-
for s, sequence_coords in enumerate(ref_layout):
|
| 146 |
-
for coords in sequence_coords:
|
| 147 |
-
if coords.t < timesteps:
|
| 148 |
-
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
| 149 |
-
mask[coords.q, s] = 1
|
| 150 |
-
indexes = torch.from_numpy(indexes).to(device)
|
| 151 |
-
mask = torch.from_numpy(mask).to(device)
|
| 152 |
-
return indexes, mask
|
| 153 |
-
|
| 154 |
-
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
| 155 |
-
"""Build sequence corresponding to the pattern from the input tensor z.
|
| 156 |
-
The sequence is built using up to sequence_steps if specified, and non-pattern
|
| 157 |
-
coordinates are filled with the special token.
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
| 161 |
-
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
| 162 |
-
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
| 163 |
-
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
| 164 |
-
Returns:
|
| 165 |
-
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
| 166 |
-
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
| 167 |
-
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
| 168 |
-
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
| 169 |
-
"""
|
| 170 |
-
B, K, T = z.shape
|
| 171 |
-
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
| 172 |
-
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
| 173 |
-
)
|
| 174 |
-
z = z.view(B, -1)
|
| 175 |
-
# we append the special token as the last index of our flattened z tensor
|
| 176 |
-
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
| 177 |
-
values = z[:, indexes.view(-1)]
|
| 178 |
-
values = values.view(B, K, indexes.shape[-1])
|
| 179 |
-
return values, indexes, mask
|
| 180 |
-
|
| 181 |
-
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
|
| 182 |
-
keep_only_valid_steps: bool = False,
|
| 183 |
-
is_model_output: bool = False,
|
| 184 |
-
device: tp.Union[torch.device, str] = 'cpu'):
|
| 185 |
-
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
| 186 |
-
from interleaving pattern.
|
| 187 |
-
|
| 188 |
-
Args:
|
| 189 |
-
sequence_steps (int): Sequence steps.
|
| 190 |
-
n_q (int): Number of codebooks.
|
| 191 |
-
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
| 192 |
-
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
| 193 |
-
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
| 194 |
-
device (torch.device or str): Device for created tensors.
|
| 195 |
-
Returns:
|
| 196 |
-
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
|
| 197 |
-
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
| 198 |
-
"""
|
| 199 |
-
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
| 200 |
-
# TODO(jade): Do we want to further truncate to only valid timesteps here as well?
|
| 201 |
-
timesteps = self.timesteps
|
| 202 |
-
assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
| 203 |
-
assert sequence_steps <= len(ref_layout), \
|
| 204 |
-
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
| 205 |
-
|
| 206 |
-
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
| 207 |
-
if is_model_output and self.starts_with_special_token():
|
| 208 |
-
ref_layout = ref_layout[1:]
|
| 209 |
-
|
| 210 |
-
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
| 211 |
-
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
|
| 212 |
-
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
|
| 213 |
-
# fill indexes with last sequence step value that will correspond to our special token
|
| 214 |
-
indexes[:] = n_q * sequence_steps
|
| 215 |
-
for s, sequence_codes in enumerate(ref_layout):
|
| 216 |
-
if s < sequence_steps:
|
| 217 |
-
for code in sequence_codes:
|
| 218 |
-
if code.t < timesteps:
|
| 219 |
-
indexes[code.q, code.t] = s + code.q * sequence_steps
|
| 220 |
-
mask[code.q, code.t] = 1
|
| 221 |
-
indexes = torch.from_numpy(indexes).to(device)
|
| 222 |
-
mask = torch.from_numpy(mask).to(device)
|
| 223 |
-
return indexes, mask
|
| 224 |
-
|
| 225 |
-
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
| 226 |
-
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
| 227 |
-
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
| 228 |
-
are filled with the special token.
|
| 229 |
-
|
| 230 |
-
Args:
|
| 231 |
-
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
| 232 |
-
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
| 233 |
-
Returns:
|
| 234 |
-
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
| 235 |
-
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
| 236 |
-
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
| 237 |
-
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
| 238 |
-
"""
|
| 239 |
-
B, K, S = s.shape
|
| 240 |
-
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
| 241 |
-
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
| 242 |
-
)
|
| 243 |
-
s = s.view(B, -1)
|
| 244 |
-
# we append the special token as the last index of our flattened z tensor
|
| 245 |
-
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
| 246 |
-
values = s[:, indexes.view(-1)]
|
| 247 |
-
values = values.view(B, K, indexes.shape[-1])
|
| 248 |
-
return values, indexes, mask
|
| 249 |
-
|
| 250 |
-
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
| 251 |
-
"""Revert model logits obtained on a sequence built from the pattern
|
| 252 |
-
back to a tensor matching the original sequence.
|
| 253 |
-
|
| 254 |
-
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
| 255 |
-
1. It is designed to work with the extra cardinality dimension
|
| 256 |
-
2. We return the logits for the first sequence item that matches the special_token and
|
| 257 |
-
which matching target in the original sequence is the first item of the sequence,
|
| 258 |
-
while we skip the last logits as there is no matching target
|
| 259 |
-
"""
|
| 260 |
-
B, card, K, S = logits.shape
|
| 261 |
-
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
| 262 |
-
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
| 263 |
-
)
|
| 264 |
-
logits = logits.reshape(B, card, -1)
|
| 265 |
-
# we append the special token as the last index of our flattened z tensor
|
| 266 |
-
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
| 267 |
-
values = logits[:, :, indexes.view(-1)]
|
| 268 |
-
values = values.view(B, card, K, indexes.shape[-1])
|
| 269 |
-
return values, indexes, mask
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
class CodebooksPatternProvider(ABC):
|
| 273 |
-
"""Abstraction around providing pattern for interleaving codebooks.
|
| 274 |
-
|
| 275 |
-
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
| 276 |
-
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
| 277 |
-
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
| 278 |
-
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
| 279 |
-
can be used to construct a new sequence from the original codes respecting the specified
|
| 280 |
-
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
| 281 |
-
being a tuple with the original timestep and codebook to build the new sequence.
|
| 282 |
-
Note that all patterns must start with an empty list that is then used to insert a first
|
| 283 |
-
sequence step of special tokens in the newly generated sequence.
|
| 284 |
-
|
| 285 |
-
Args:
|
| 286 |
-
n_q (int): number of codebooks.
|
| 287 |
-
cached (bool): if True, patterns for a given length are cached. In general
|
| 288 |
-
that should be true for efficiency reason to avoid synchronization points.
|
| 289 |
-
"""
|
| 290 |
-
def __init__(self, n_q: int, cached: bool = True):
|
| 291 |
-
assert n_q > 0
|
| 292 |
-
self.n_q = n_q
|
| 293 |
-
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
| 294 |
-
|
| 295 |
-
@abstractmethod
|
| 296 |
-
def get_pattern(self, timesteps: int) -> Pattern:
|
| 297 |
-
"""Builds pattern with specific interleaving between codebooks.
|
| 298 |
-
|
| 299 |
-
Args:
|
| 300 |
-
timesteps (int): Total number of timesteps.
|
| 301 |
-
"""
|
| 302 |
-
raise NotImplementedError()
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
class DelayedPatternProvider(CodebooksPatternProvider):
|
| 306 |
-
"""Provider for delayed pattern across delayed codebooks.
|
| 307 |
-
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
| 308 |
-
from different timesteps.
|
| 309 |
-
|
| 310 |
-
Example:
|
| 311 |
-
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
| 312 |
-
[[1, 2, 3, 4],
|
| 313 |
-
[1, 2, 3, 4],
|
| 314 |
-
[1, 2, 3, 4]]
|
| 315 |
-
The resulting sequence obtained from the returned pattern is:
|
| 316 |
-
[[S, 1, 2, 3, 4],
|
| 317 |
-
[S, S, 1, 2, 3],
|
| 318 |
-
[S, S, S, 1, 2]]
|
| 319 |
-
(with S being a special token)
|
| 320 |
-
|
| 321 |
-
Args:
|
| 322 |
-
n_q (int): Number of codebooks.
|
| 323 |
-
delays (list of int, optional): Delay for each of the codebooks.
|
| 324 |
-
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
| 325 |
-
flatten_first (int): Flatten the first N timesteps.
|
| 326 |
-
empty_initial (int): Prepend with N empty list of coordinates.
|
| 327 |
-
"""
|
| 328 |
-
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
|
| 329 |
-
flatten_first: int = 0, empty_initial: int = 0):
|
| 330 |
-
super().__init__(n_q)
|
| 331 |
-
if delays is None:
|
| 332 |
-
delays = list(range(n_q))
|
| 333 |
-
self.delays = delays
|
| 334 |
-
self.flatten_first = flatten_first
|
| 335 |
-
self.empty_initial = empty_initial
|
| 336 |
-
assert len(self.delays) == self.n_q
|
| 337 |
-
assert sorted(self.delays) == self.delays
|
| 338 |
-
|
| 339 |
-
def get_pattern(self, timesteps: int) -> Pattern:
|
| 340 |
-
omit_special_token = self.empty_initial < 0
|
| 341 |
-
out: PatternLayout = [] if omit_special_token else [[]]
|
| 342 |
-
max_delay = max(self.delays)
|
| 343 |
-
if self.empty_initial:
|
| 344 |
-
out += [[] for _ in range(self.empty_initial)]
|
| 345 |
-
if self.flatten_first:
|
| 346 |
-
for t in range(min(timesteps, self.flatten_first)):
|
| 347 |
-
for q in range(self.n_q):
|
| 348 |
-
out.append([LayoutCoord(t, q)])
|
| 349 |
-
for t in range(self.flatten_first, timesteps + max_delay):
|
| 350 |
-
v = []
|
| 351 |
-
for q, delay in enumerate(self.delays):
|
| 352 |
-
t_for_q = t - delay
|
| 353 |
-
if t_for_q >= self.flatten_first:
|
| 354 |
-
v.append(LayoutCoord(t_for_q, q))
|
| 355 |
-
out.append(v)
|
| 356 |
-
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
class ParallelPatternProvider(DelayedPatternProvider):
|
| 360 |
-
"""Provider for parallel pattern across codebooks.
|
| 361 |
-
This pattern provider is a special case of the delayed pattern with actually no delay,
|
| 362 |
-
hence delays=repeat(0, n_q).
|
| 363 |
-
|
| 364 |
-
Args:
|
| 365 |
-
n_q (int): Number of codebooks.
|
| 366 |
-
empty_initial (int): Prepend with N empty list of coordinates.
|
| 367 |
-
"""
|
| 368 |
-
def __init__(self, n_q: int, empty_initial: int = 0):
|
| 369 |
-
super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
class UnrolledPatternProvider(CodebooksPatternProvider):
|
| 373 |
-
"""Provider for unrolling codebooks pattern.
|
| 374 |
-
This pattern provider enables to represent the codebook flattened completely or only to some extend
|
| 375 |
-
while also specifying a given delay between the flattened codebooks representation, allowing to
|
| 376 |
-
unroll the codebooks in the sequence.
|
| 377 |
-
|
| 378 |
-
Example:
|
| 379 |
-
1. Flattening of the codebooks.
|
| 380 |
-
By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
|
| 381 |
-
taking n_q = 3 and timesteps = 4:
|
| 382 |
-
[[1, 2, 3, 4],
|
| 383 |
-
[1, 2, 3, 4],
|
| 384 |
-
[1, 2, 3, 4]]
|
| 385 |
-
will result into:
|
| 386 |
-
[[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
|
| 387 |
-
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
| 388 |
-
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
| 389 |
-
2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
|
| 390 |
-
for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
|
| 391 |
-
taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
|
| 392 |
-
[[1, 2, 3, 4],
|
| 393 |
-
[1, 2, 3, 4],
|
| 394 |
-
[1, 2, 3, 4]]
|
| 395 |
-
will result into:
|
| 396 |
-
[[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
| 397 |
-
[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
|
| 398 |
-
[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
|
| 399 |
-
3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
|
| 400 |
-
allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
|
| 401 |
-
same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
|
| 402 |
-
and delays = [0, 3, 3]:
|
| 403 |
-
[[1, 2, 3, 4],
|
| 404 |
-
[1, 2, 3, 4],
|
| 405 |
-
[1, 2, 3, 4]]
|
| 406 |
-
will result into:
|
| 407 |
-
[[S, S, S, 1, S, 2, S, 3, S, 4],
|
| 408 |
-
[S, S, S, 1, S, 2, S, 3, S, 4],
|
| 409 |
-
[1, 2, 3, S, 4, S, 5, S, 6, S]]
|
| 410 |
-
|
| 411 |
-
Args:
|
| 412 |
-
n_q (int): Number of codebooks.
|
| 413 |
-
flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
|
| 414 |
-
the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
|
| 415 |
-
have n_q extra steps for each timestep.
|
| 416 |
-
delays (list of int, optional): Delay for each of the codebooks. If not defined,
|
| 417 |
-
no delay is added and therefore will default to [0] * ``n_q``.
|
| 418 |
-
Note that two codebooks that will be flattened to the same inner step
|
| 419 |
-
should have the same delay, otherwise the pattern is considered as invalid.
|
| 420 |
-
"""
|
| 421 |
-
FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
|
| 422 |
-
|
| 423 |
-
def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
|
| 424 |
-
delays: tp.Optional[tp.List[int]] = None):
|
| 425 |
-
super().__init__(n_q)
|
| 426 |
-
if flattening is None:
|
| 427 |
-
flattening = list(range(n_q))
|
| 428 |
-
if delays is None:
|
| 429 |
-
delays = [0] * n_q
|
| 430 |
-
assert len(flattening) == n_q
|
| 431 |
-
assert len(delays) == n_q
|
| 432 |
-
assert sorted(flattening) == flattening
|
| 433 |
-
assert sorted(delays) == delays
|
| 434 |
-
self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
|
| 435 |
-
self.max_delay = max(delays)
|
| 436 |
-
|
| 437 |
-
def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
|
| 438 |
-
"""Build a flattened codebooks representation as a dictionary of inner step
|
| 439 |
-
and the actual codebook indices corresponding to the flattened codebook. For convenience, we
|
| 440 |
-
also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
|
| 441 |
-
"""
|
| 442 |
-
flattened_codebooks: dict = {}
|
| 443 |
-
for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
|
| 444 |
-
if inner_step not in flattened_codebooks:
|
| 445 |
-
flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
|
| 446 |
-
else:
|
| 447 |
-
flat_codebook = flattened_codebooks[inner_step]
|
| 448 |
-
assert flat_codebook.delay == delay, (
|
| 449 |
-
"Delay and flattening between codebooks is inconsistent: ",
|
| 450 |
-
"two codebooks flattened to the same position should have the same delay."
|
| 451 |
-
)
|
| 452 |
-
flat_codebook.codebooks.append(q)
|
| 453 |
-
flattened_codebooks[inner_step] = flat_codebook
|
| 454 |
-
return flattened_codebooks
|
| 455 |
-
|
| 456 |
-
@property
|
| 457 |
-
def _num_inner_steps(self):
|
| 458 |
-
"""Number of inner steps to unroll between timesteps in order to flatten the codebooks.
|
| 459 |
-
"""
|
| 460 |
-
return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
|
| 461 |
-
|
| 462 |
-
def num_virtual_steps(self, timesteps: int) -> int:
|
| 463 |
-
return timesteps * self._num_inner_steps + 1
|
| 464 |
-
|
| 465 |
-
def get_pattern(self, timesteps: int) -> Pattern:
|
| 466 |
-
"""Builds pattern for delay across codebooks.
|
| 467 |
-
|
| 468 |
-
Args:
|
| 469 |
-
timesteps (int): Total number of timesteps.
|
| 470 |
-
"""
|
| 471 |
-
# the PatternLayout is built as a tuple of sequence position and list of coordinates
|
| 472 |
-
# so that it can be reordered properly given the required delay between codebooks of given timesteps
|
| 473 |
-
indexed_out: list = [(-1, [])]
|
| 474 |
-
max_timesteps = timesteps + self.max_delay
|
| 475 |
-
for t in range(max_timesteps):
|
| 476 |
-
# for each timestep, we unroll the flattened codebooks,
|
| 477 |
-
# emitting the sequence step with the corresponding delay
|
| 478 |
-
for step in range(self._num_inner_steps):
|
| 479 |
-
if step in self._flattened_codebooks:
|
| 480 |
-
# we have codebooks at this virtual step to emit
|
| 481 |
-
step_codebooks = self._flattened_codebooks[step]
|
| 482 |
-
t_for_q = t + step_codebooks.delay
|
| 483 |
-
coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
|
| 484 |
-
if t_for_q < max_timesteps and t < max_timesteps:
|
| 485 |
-
indexed_out.append((t_for_q, coords))
|
| 486 |
-
else:
|
| 487 |
-
# there is no codebook in this virtual step so we emit an empty list
|
| 488 |
-
indexed_out.append((t, []))
|
| 489 |
-
out = [coords for _, coords in sorted(indexed_out)]
|
| 490 |
-
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
class CoarseFirstPattern(CodebooksPatternProvider):
|
| 494 |
-
"""First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
|
| 495 |
-
potentially with delays.
|
| 496 |
-
|
| 497 |
-
..Warning:: You must always generate the full training duration at test time, for instance,
|
| 498 |
-
30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
|
| 499 |
-
location. This is due to the non causality of the remaining codebooks with respect to
|
| 500 |
-
the first ones.
|
| 501 |
-
|
| 502 |
-
Args:
|
| 503 |
-
n_q (int): Number of codebooks.
|
| 504 |
-
delays (list of int, optional): Delay for each of the codebooks.
|
| 505 |
-
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
| 506 |
-
"""
|
| 507 |
-
def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
|
| 508 |
-
super().__init__(n_q)
|
| 509 |
-
if delays is None:
|
| 510 |
-
delays = [0] * (n_q - 1)
|
| 511 |
-
self.delays = delays
|
| 512 |
-
assert len(self.delays) == self.n_q - 1
|
| 513 |
-
assert sorted(self.delays) == self.delays
|
| 514 |
-
|
| 515 |
-
def get_pattern(self, timesteps: int) -> Pattern:
|
| 516 |
-
out: PatternLayout = [[]]
|
| 517 |
-
for t in range(timesteps):
|
| 518 |
-
out.append([LayoutCoord(t, 0)])
|
| 519 |
-
max_delay = max(self.delays)
|
| 520 |
-
for t in range(timesteps + max_delay):
|
| 521 |
-
v = []
|
| 522 |
-
for q, delay in enumerate(self.delays):
|
| 523 |
-
t_for_q = t - delay
|
| 524 |
-
if t_for_q >= 0:
|
| 525 |
-
v.append(LayoutCoord(t_for_q, q + 1))
|
| 526 |
-
out.append(v)
|
| 527 |
-
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
class MusicLMPattern(CodebooksPatternProvider):
|
| 531 |
-
"""Almost MusicLM style pattern. This is equivalent to full flattening
|
| 532 |
-
but in a different order.
|
| 533 |
-
|
| 534 |
-
Args:
|
| 535 |
-
n_q (int): Number of codebooks.
|
| 536 |
-
group_by (int): Number of codebooks to group together.
|
| 537 |
-
"""
|
| 538 |
-
def __init__(self, n_q: int, group_by: int = 2):
|
| 539 |
-
super().__init__(n_q)
|
| 540 |
-
self.group_by = group_by
|
| 541 |
-
|
| 542 |
-
def get_pattern(self, timesteps: int) -> Pattern:
|
| 543 |
-
out: PatternLayout = [[]]
|
| 544 |
-
for offset in range(0, self.n_q, self.group_by):
|
| 545 |
-
for t in range(timesteps):
|
| 546 |
-
for q in range(offset, offset + self.group_by):
|
| 547 |
-
out.append([LayoutCoord(t, q)])
|
| 548 |
-
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/conditioners.py
DELETED
|
@@ -1,1763 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from collections import defaultdict
|
| 8 |
-
from copy import deepcopy
|
| 9 |
-
from dataclasses import dataclass, field
|
| 10 |
-
from itertools import chain
|
| 11 |
-
import logging
|
| 12 |
-
import math
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import random
|
| 15 |
-
import re
|
| 16 |
-
import typing as tp
|
| 17 |
-
import warnings
|
| 18 |
-
import einops
|
| 19 |
-
import flashy
|
| 20 |
-
from num2words import num2words
|
| 21 |
-
import spacy
|
| 22 |
-
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
|
| 23 |
-
import torch
|
| 24 |
-
from torch import nn
|
| 25 |
-
import torch.nn.functional as F
|
| 26 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 27 |
-
from enum import Enum
|
| 28 |
-
from .chroma import ChromaExtractor
|
| 29 |
-
from .streaming import StreamingModule
|
| 30 |
-
from .transformer import create_sin_embedding, StreamingTransformer
|
| 31 |
-
from ..data.audio import audio_read
|
| 32 |
-
from ..data.audio_dataset import SegmentInfo
|
| 33 |
-
from ..data.audio_utils import convert_audio
|
| 34 |
-
from ..environment import AudioCraftEnvironment
|
| 35 |
-
from ..quantization import ResidualVectorQuantizer
|
| 36 |
-
from ..utils.autocast import TorchAutocast
|
| 37 |
-
from ..utils.cache import EmbeddingCache
|
| 38 |
-
from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
logger = logging.getLogger(__name__)
|
| 42 |
-
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
|
| 43 |
-
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class JascoCondConst(Enum):
|
| 47 |
-
DRM = 'self_wav'
|
| 48 |
-
CRD = 'chords'
|
| 49 |
-
MLD = 'melody'
|
| 50 |
-
SYM = {'chords', 'melody'}
|
| 51 |
-
LAT = {'self_wav'}
|
| 52 |
-
ALL = ['chords', 'self_wav', 'melody'] # order matters
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class WavCondition(tp.NamedTuple):
|
| 56 |
-
wav: torch.Tensor
|
| 57 |
-
length: torch.Tensor
|
| 58 |
-
sample_rate: tp.List[int]
|
| 59 |
-
path: tp.List[tp.Optional[str]] = []
|
| 60 |
-
seek_time: tp.List[tp.Optional[float]] = []
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class JointEmbedCondition(tp.NamedTuple):
|
| 64 |
-
wav: torch.Tensor
|
| 65 |
-
text: tp.List[tp.Optional[str]]
|
| 66 |
-
length: torch.Tensor
|
| 67 |
-
sample_rate: tp.List[int]
|
| 68 |
-
path: tp.List[tp.Optional[str]] = []
|
| 69 |
-
seek_time: tp.List[tp.Optional[float]] = []
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class SymbolicCondition(tp.NamedTuple):
|
| 73 |
-
frame_chords: tp.Optional[torch.Tensor] = None
|
| 74 |
-
melody: tp.Optional[torch.Tensor] = None
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@dataclass
|
| 78 |
-
class ConditioningAttributes:
|
| 79 |
-
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
| 80 |
-
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
| 81 |
-
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
| 82 |
-
symbolic: tp.Dict[str, SymbolicCondition] = field(default_factory=dict)
|
| 83 |
-
|
| 84 |
-
def __getitem__(self, item):
|
| 85 |
-
return getattr(self, item)
|
| 86 |
-
|
| 87 |
-
@property
|
| 88 |
-
def text_attributes(self):
|
| 89 |
-
return self.text.keys()
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def wav_attributes(self):
|
| 93 |
-
return self.wav.keys()
|
| 94 |
-
|
| 95 |
-
@property
|
| 96 |
-
def joint_embed_attributes(self):
|
| 97 |
-
return self.joint_embed.keys()
|
| 98 |
-
|
| 99 |
-
@property
|
| 100 |
-
def symbolic_attributes(self):
|
| 101 |
-
return self.symbolic.keys()
|
| 102 |
-
|
| 103 |
-
@property
|
| 104 |
-
def attributes(self):
|
| 105 |
-
return {
|
| 106 |
-
"text": self.text_attributes,
|
| 107 |
-
"wav": self.wav_attributes,
|
| 108 |
-
"joint_embed": self.joint_embed_attributes,
|
| 109 |
-
"symbolic": self.symbolic_attributes,
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
def to_flat_dict(self):
|
| 113 |
-
return {
|
| 114 |
-
**{f"text.{k}": v for k, v in self.text.items()},
|
| 115 |
-
**{f"wav.{k}": v for k, v in self.wav.items()},
|
| 116 |
-
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()},
|
| 117 |
-
**{f"symbolic.{k}": v for k, v in self.symbolic.items()}
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
@classmethod
|
| 121 |
-
def from_flat_dict(cls, x):
|
| 122 |
-
out = cls()
|
| 123 |
-
for k, v in x.items():
|
| 124 |
-
kind, att = k.split(".")
|
| 125 |
-
out[kind][att] = v
|
| 126 |
-
return out
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
class SegmentWithAttributes(SegmentInfo):
|
| 130 |
-
"""Base class for all dataclasses that are used for conditioning.
|
| 131 |
-
All child classes should implement `to_condition_attributes` that converts
|
| 132 |
-
the existing attributes to a dataclass of type ConditioningAttributes.
|
| 133 |
-
"""
|
| 134 |
-
def to_condition_attributes(self) -> ConditioningAttributes:
|
| 135 |
-
raise NotImplementedError()
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def nullify_condition(condition: ConditionType, dim: int = 1):
|
| 139 |
-
"""Transform an input condition to a null condition.
|
| 140 |
-
The way it is done by converting it to a single zero vector similarly
|
| 141 |
-
to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
|
| 142 |
-
|
| 143 |
-
Args:
|
| 144 |
-
condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
|
| 145 |
-
dim (int): The dimension that will be truncated (should be the time dimension)
|
| 146 |
-
WARNING!: dim should not be the batch dimension!
|
| 147 |
-
Returns:
|
| 148 |
-
ConditionType: A tuple of null condition and mask
|
| 149 |
-
"""
|
| 150 |
-
assert dim != 0, "dim cannot be the batch dimension!"
|
| 151 |
-
assert isinstance(condition, tuple) and \
|
| 152 |
-
isinstance(condition[0], torch.Tensor) and \
|
| 153 |
-
isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
|
| 154 |
-
cond, mask = condition
|
| 155 |
-
B = cond.shape[0]
|
| 156 |
-
last_dim = cond.dim() - 1
|
| 157 |
-
out = cond.transpose(dim, last_dim)
|
| 158 |
-
out = 0. * out[..., :1]
|
| 159 |
-
out = out.transpose(dim, last_dim)
|
| 160 |
-
mask = torch.zeros((B, 1), device=out.device).int()
|
| 161 |
-
assert cond.dim() == out.dim()
|
| 162 |
-
return out, mask
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def nullify_wav(cond: WavCondition) -> WavCondition:
|
| 166 |
-
"""Transform a WavCondition to a nullified WavCondition.
|
| 167 |
-
It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
|
| 168 |
-
|
| 169 |
-
Args:
|
| 170 |
-
cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
|
| 171 |
-
Returns:
|
| 172 |
-
WavCondition: Nullified wav condition.
|
| 173 |
-
"""
|
| 174 |
-
null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
|
| 175 |
-
return WavCondition(
|
| 176 |
-
wav=null_wav,
|
| 177 |
-
length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
|
| 178 |
-
sample_rate=cond.sample_rate,
|
| 179 |
-
path=[None] * cond.wav.shape[0],
|
| 180 |
-
seek_time=[None] * cond.wav.shape[0],
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
|
| 185 |
-
"""Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
|
| 186 |
-
and replacing metadata by dummy attributes.
|
| 187 |
-
|
| 188 |
-
Args:
|
| 189 |
-
cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
|
| 190 |
-
"""
|
| 191 |
-
null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
|
| 192 |
-
return JointEmbedCondition(
|
| 193 |
-
wav=null_wav, text=[None] * len(embed.text),
|
| 194 |
-
length=torch.LongTensor([0]).to(embed.wav.device),
|
| 195 |
-
sample_rate=embed.sample_rate,
|
| 196 |
-
path=[None] * embed.wav.shape[0],
|
| 197 |
-
seek_time=[0] * embed.wav.shape[0],
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 194) -> SymbolicCondition:
|
| 202 |
-
"""Nullify the symbolic condition by setting all frame chords to a specified null chord index.
|
| 203 |
-
Args:
|
| 204 |
-
sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
|
| 205 |
-
null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
|
| 206 |
-
Returns:
|
| 207 |
-
SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
|
| 208 |
-
"""
|
| 209 |
-
return SymbolicCondition(frame_chords=torch.ones_like(sym_cond.frame_chords) * null_chord_idx) # type: ignore
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition:
|
| 213 |
-
"""Nullify the symbolic condition by replacing the melody matrix with zeros matrix.
|
| 214 |
-
Args:
|
| 215 |
-
sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified.
|
| 216 |
-
null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino).
|
| 217 |
-
Returns:
|
| 218 |
-
SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index.
|
| 219 |
-
"""
|
| 220 |
-
return SymbolicCondition(melody=torch.zeros_like(sym_cond.melody)) # type: ignore
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
def _drop_description_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
| 224 |
-
"""Drop the text condition but keep the wav conditon on a list of ConditioningAttributes.
|
| 225 |
-
This is useful to calculate l_style in the double classifier free guidance formula.
|
| 226 |
-
See paragraph 4.3 in https://arxiv.org/pdf/2407.12563
|
| 227 |
-
|
| 228 |
-
Args:
|
| 229 |
-
conditions (tp.List[ConditioningAttributes]): List of conditions.
|
| 230 |
-
"""
|
| 231 |
-
# We assert that description and self_wav are in the conditions
|
| 232 |
-
for condition in conditions:
|
| 233 |
-
assert 'description' in condition.text.keys()
|
| 234 |
-
assert 'self_wav' in condition.wav.keys()
|
| 235 |
-
return AttributeDropout(p={'text': {'description': 1.0},
|
| 236 |
-
'wav': {'self_wav': 0.0}})(conditions)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class Tokenizer:
|
| 240 |
-
"""Base tokenizer implementation
|
| 241 |
-
(in case we want to introduce more advances tokenizers in the future).
|
| 242 |
-
"""
|
| 243 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 244 |
-
raise NotImplementedError()
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
class WhiteSpaceTokenizer(Tokenizer):
|
| 248 |
-
"""This tokenizer should be used for natural language descriptions.
|
| 249 |
-
For example:
|
| 250 |
-
["he didn't, know he's going home.", 'shorter sentence'] =>
|
| 251 |
-
[[78, 62, 31, 4, 78, 25, 19, 34],
|
| 252 |
-
[59, 77, 0, 0, 0, 0, 0, 0]]
|
| 253 |
-
"""
|
| 254 |
-
PUNCTUATION = "?:!.,;"
|
| 255 |
-
|
| 256 |
-
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
| 257 |
-
lemma: bool = True, stopwords: bool = True) -> None:
|
| 258 |
-
self.n_bins = n_bins
|
| 259 |
-
self.pad_idx = pad_idx
|
| 260 |
-
self.lemma = lemma
|
| 261 |
-
self.stopwords = stopwords
|
| 262 |
-
try:
|
| 263 |
-
self.nlp = spacy.load(language)
|
| 264 |
-
except IOError:
|
| 265 |
-
spacy.cli.download(language) # type: ignore
|
| 266 |
-
self.nlp = spacy.load(language)
|
| 267 |
-
|
| 268 |
-
@tp.no_type_check
|
| 269 |
-
def __call__(self, texts: tp.List[tp.Optional[str]],
|
| 270 |
-
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 271 |
-
"""Take a list of strings and convert them to a tensor of indices.
|
| 272 |
-
|
| 273 |
-
Args:
|
| 274 |
-
texts (list[str]): List of strings.
|
| 275 |
-
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
| 276 |
-
Returns:
|
| 277 |
-
tuple[torch.Tensor, torch.Tensor]:
|
| 278 |
-
- Indices of words in the LUT.
|
| 279 |
-
- And a mask indicating where the padding tokens are
|
| 280 |
-
"""
|
| 281 |
-
output, lengths = [], []
|
| 282 |
-
texts = deepcopy(texts)
|
| 283 |
-
for i, text in enumerate(texts):
|
| 284 |
-
# if current sample doesn't have a certain attribute, replace with pad token
|
| 285 |
-
if text is None:
|
| 286 |
-
output.append(torch.Tensor([self.pad_idx]))
|
| 287 |
-
lengths.append(0)
|
| 288 |
-
continue
|
| 289 |
-
|
| 290 |
-
# convert numbers to words
|
| 291 |
-
text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
|
| 292 |
-
# normalize text
|
| 293 |
-
text = self.nlp(text) # type: ignore
|
| 294 |
-
# remove stopwords
|
| 295 |
-
if self.stopwords:
|
| 296 |
-
text = [w for w in text if not w.is_stop] # type: ignore
|
| 297 |
-
# remove punctuation
|
| 298 |
-
text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
|
| 299 |
-
# lemmatize if needed
|
| 300 |
-
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
| 301 |
-
|
| 302 |
-
texts[i] = " ".join(text)
|
| 303 |
-
lengths.append(len(text))
|
| 304 |
-
# convert to tensor
|
| 305 |
-
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
|
| 306 |
-
output.append(tokens)
|
| 307 |
-
|
| 308 |
-
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
| 309 |
-
padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
|
| 310 |
-
if return_text:
|
| 311 |
-
return padded_output, mask, texts # type: ignore
|
| 312 |
-
return padded_output, mask
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
class NoopTokenizer(Tokenizer):
|
| 316 |
-
"""This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
|
| 317 |
-
The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
|
| 318 |
-
strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
|
| 319 |
-
split it to ["Jeff", "Buckley"] and return an index per word.
|
| 320 |
-
|
| 321 |
-
For example:
|
| 322 |
-
["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
|
| 323 |
-
["Metal", "Rock", "Classical"] => [0, 223, 51]
|
| 324 |
-
"""
|
| 325 |
-
def __init__(self, n_bins: int, pad_idx: int = 0):
|
| 326 |
-
self.n_bins = n_bins
|
| 327 |
-
self.pad_idx = pad_idx
|
| 328 |
-
|
| 329 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 330 |
-
output, lengths = [], []
|
| 331 |
-
for text in texts:
|
| 332 |
-
# if current sample doesn't have a certain attribute, replace with pad token
|
| 333 |
-
if text is None:
|
| 334 |
-
output.append(self.pad_idx)
|
| 335 |
-
lengths.append(0)
|
| 336 |
-
else:
|
| 337 |
-
output.append(hash_trick(text, self.n_bins))
|
| 338 |
-
lengths.append(1)
|
| 339 |
-
|
| 340 |
-
tokens = torch.LongTensor(output).unsqueeze(1)
|
| 341 |
-
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
| 342 |
-
return tokens, mask
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
class BaseConditioner(nn.Module):
|
| 346 |
-
"""Base model for all conditioner modules.
|
| 347 |
-
We allow the output dim to be different than the hidden dim for two reasons:
|
| 348 |
-
1) keep our LUTs small when the vocab is large;
|
| 349 |
-
2) make all condition dims consistent.
|
| 350 |
-
|
| 351 |
-
Args:
|
| 352 |
-
dim (int): Hidden dim of the model.
|
| 353 |
-
output_dim (int): Output dim of the conditioner.
|
| 354 |
-
"""
|
| 355 |
-
def __init__(self, dim: int, output_dim: int):
|
| 356 |
-
super().__init__()
|
| 357 |
-
self.dim = dim
|
| 358 |
-
self.output_dim = output_dim
|
| 359 |
-
if self.output_dim > -1: # omit projection when output_dim <= 0
|
| 360 |
-
self.output_proj = nn.Linear(dim, output_dim)
|
| 361 |
-
|
| 362 |
-
def tokenize(self, *args, **kwargs) -> tp.Any:
|
| 363 |
-
"""Should be any part of the processing that will lead to a synchronization
|
| 364 |
-
point, e.g. BPE tokenization with transfer to the GPU.
|
| 365 |
-
|
| 366 |
-
The returned value will be saved and return later when calling forward().
|
| 367 |
-
"""
|
| 368 |
-
raise NotImplementedError()
|
| 369 |
-
|
| 370 |
-
def forward(self, inputs: tp.Any) -> ConditionType:
|
| 371 |
-
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
|
| 372 |
-
Outputs a ConditionType, after the input data was embedded as a dense vector.
|
| 373 |
-
|
| 374 |
-
Returns:
|
| 375 |
-
ConditionType:
|
| 376 |
-
- A tensor of size [B, T, D] where B is the batch size, T is the length of the
|
| 377 |
-
output embedding and D is the dimension of the embedding.
|
| 378 |
-
- And a mask indicating where the padding tokens.
|
| 379 |
-
"""
|
| 380 |
-
raise NotImplementedError()
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
class TextConditioner(BaseConditioner):
|
| 384 |
-
...
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
class LUTConditioner(TextConditioner):
|
| 388 |
-
"""Lookup table TextConditioner.
|
| 389 |
-
|
| 390 |
-
Args:
|
| 391 |
-
n_bins (int): Number of bins.
|
| 392 |
-
dim (int): Hidden dim of the model (text-encoder/LUT).
|
| 393 |
-
output_dim (int): Output dim of the conditioner.
|
| 394 |
-
tokenizer (str): Name of the tokenizer.
|
| 395 |
-
pad_idx (int, optional): Index for padding token. Defaults to 0.
|
| 396 |
-
"""
|
| 397 |
-
def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
|
| 398 |
-
super().__init__(dim, output_dim)
|
| 399 |
-
self.embed = nn.Embedding(n_bins, dim)
|
| 400 |
-
self.tokenizer: Tokenizer
|
| 401 |
-
if tokenizer == 'whitespace':
|
| 402 |
-
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
| 403 |
-
elif tokenizer == 'noop':
|
| 404 |
-
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
| 405 |
-
else:
|
| 406 |
-
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
| 407 |
-
|
| 408 |
-
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 409 |
-
device = self.embed.weight.device
|
| 410 |
-
tokens, mask = self.tokenizer(x)
|
| 411 |
-
tokens, mask = tokens.to(device), mask.to(device)
|
| 412 |
-
return tokens, mask
|
| 413 |
-
|
| 414 |
-
def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
|
| 415 |
-
tokens, mask = inputs
|
| 416 |
-
embeds = self.embed(tokens)
|
| 417 |
-
embeds = self.output_proj(embeds)
|
| 418 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
| 419 |
-
return embeds, mask
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
class T5Conditioner(TextConditioner):
|
| 423 |
-
"""T5-based TextConditioner.
|
| 424 |
-
|
| 425 |
-
Args:
|
| 426 |
-
name (str): Name of the T5 model.
|
| 427 |
-
output_dim (int): Output dim of the conditioner.
|
| 428 |
-
finetune (bool): Whether to fine-tune T5 at train time.
|
| 429 |
-
device (str): Device for T5 Conditioner.
|
| 430 |
-
autocast_dtype (tp.Optional[str], optional): Autocast dtype.
|
| 431 |
-
word_dropout (float, optional): Word dropout probability.
|
| 432 |
-
normalize_text (bool, optional): Whether to apply text normalization.
|
| 433 |
-
"""
|
| 434 |
-
MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
|
| 435 |
-
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
|
| 436 |
-
"google/flan-t5-xl", "google/flan-t5-xxl"]
|
| 437 |
-
MODELS_DIMS = {
|
| 438 |
-
"t5-small": 512,
|
| 439 |
-
"t5-base": 768,
|
| 440 |
-
"t5-large": 1024,
|
| 441 |
-
"t5-3b": 1024,
|
| 442 |
-
"t5-11b": 1024,
|
| 443 |
-
"google/flan-t5-small": 512,
|
| 444 |
-
"google/flan-t5-base": 768,
|
| 445 |
-
"google/flan-t5-large": 1024,
|
| 446 |
-
"google/flan-t5-3b": 1024,
|
| 447 |
-
"google/flan-t5-11b": 1024,
|
| 448 |
-
}
|
| 449 |
-
|
| 450 |
-
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
| 451 |
-
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
| 452 |
-
normalize_text: bool = False):
|
| 453 |
-
assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
|
| 454 |
-
super().__init__(self.MODELS_DIMS[name], output_dim)
|
| 455 |
-
self.device = device
|
| 456 |
-
self.name = name
|
| 457 |
-
self.finetune = finetune
|
| 458 |
-
self.word_dropout = word_dropout
|
| 459 |
-
if autocast_dtype is None or self.device == 'cpu':
|
| 460 |
-
self.autocast = TorchAutocast(enabled=False)
|
| 461 |
-
if self.device != 'cpu':
|
| 462 |
-
logger.warning("T5 has no autocast, this might lead to NaN")
|
| 463 |
-
else:
|
| 464 |
-
dtype = getattr(torch, autocast_dtype)
|
| 465 |
-
assert isinstance(dtype, torch.dtype)
|
| 466 |
-
logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
|
| 467 |
-
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
| 468 |
-
# Let's disable logging temporarily because T5 will vomit some errors otherwise.
|
| 469 |
-
# thanks https://gist.github.com/simon-weber/7853144
|
| 470 |
-
previous_level = logging.root.manager.disable
|
| 471 |
-
logging.disable(logging.ERROR)
|
| 472 |
-
with warnings.catch_warnings():
|
| 473 |
-
warnings.simplefilter("ignore")
|
| 474 |
-
try:
|
| 475 |
-
self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
|
| 476 |
-
t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
|
| 477 |
-
finally:
|
| 478 |
-
logging.disable(previous_level)
|
| 479 |
-
if finetune:
|
| 480 |
-
self.t5 = t5
|
| 481 |
-
else:
|
| 482 |
-
# this makes sure that the t5 models is not part
|
| 483 |
-
# of the saved checkpoint
|
| 484 |
-
self.__dict__['t5'] = t5.to(device)
|
| 485 |
-
|
| 486 |
-
self.normalize_text = normalize_text
|
| 487 |
-
if normalize_text:
|
| 488 |
-
self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
|
| 489 |
-
|
| 490 |
-
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
| 491 |
-
# if current sample doesn't have a certain attribute, replace with empty string
|
| 492 |
-
entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
|
| 493 |
-
if self.normalize_text:
|
| 494 |
-
_, _, entries = self.text_normalizer(entries, return_text=True)
|
| 495 |
-
if self.word_dropout > 0. and self.training:
|
| 496 |
-
new_entries = []
|
| 497 |
-
for entry in entries:
|
| 498 |
-
words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
|
| 499 |
-
new_entries.append(" ".join(words))
|
| 500 |
-
entries = new_entries
|
| 501 |
-
|
| 502 |
-
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
|
| 503 |
-
|
| 504 |
-
inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
|
| 505 |
-
mask = inputs['attention_mask']
|
| 506 |
-
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
| 507 |
-
return inputs
|
| 508 |
-
|
| 509 |
-
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
| 510 |
-
mask = inputs['attention_mask']
|
| 511 |
-
with torch.set_grad_enabled(self.finetune), self.autocast:
|
| 512 |
-
embeds = self.t5(**inputs).last_hidden_state
|
| 513 |
-
embeds = self.output_proj(embeds.to(self.output_proj.weight))
|
| 514 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
| 515 |
-
return embeds, mask
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
class WaveformConditioner(BaseConditioner):
|
| 519 |
-
"""Base class for all conditioners that take a waveform as input.
|
| 520 |
-
Classes that inherit must implement `_get_wav_embedding` that outputs
|
| 521 |
-
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
| 522 |
-
factor of the embedding model.
|
| 523 |
-
|
| 524 |
-
Args:
|
| 525 |
-
dim (int): The internal representation dimension.
|
| 526 |
-
output_dim (int): Output dimension.
|
| 527 |
-
device (tp.Union[torch.device, str]): Device.
|
| 528 |
-
"""
|
| 529 |
-
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
| 530 |
-
super().__init__(dim, output_dim)
|
| 531 |
-
self.device = device
|
| 532 |
-
# if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
|
| 533 |
-
self._use_masking = True
|
| 534 |
-
|
| 535 |
-
def tokenize(self, x: WavCondition) -> WavCondition:
|
| 536 |
-
wav, length, sample_rate, path, seek_time = x
|
| 537 |
-
assert length is not None
|
| 538 |
-
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
|
| 539 |
-
|
| 540 |
-
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 541 |
-
"""Gets as input a WavCondition and returns a dense embedding."""
|
| 542 |
-
raise NotImplementedError()
|
| 543 |
-
|
| 544 |
-
def _downsampling_factor(self):
|
| 545 |
-
"""Returns the downsampling factor of the embedding model."""
|
| 546 |
-
raise NotImplementedError()
|
| 547 |
-
|
| 548 |
-
def forward(self, x: WavCondition) -> ConditionType:
|
| 549 |
-
"""Extract condition embedding and mask from a waveform and its metadata.
|
| 550 |
-
Args:
|
| 551 |
-
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
| 552 |
-
Returns:
|
| 553 |
-
ConditionType: a dense vector representing the conditioning along with its mask
|
| 554 |
-
"""
|
| 555 |
-
wav, lengths, *_ = x
|
| 556 |
-
with torch.no_grad():
|
| 557 |
-
embeds = self._get_wav_embedding(x)
|
| 558 |
-
if hasattr(self, 'output_proj'):
|
| 559 |
-
embeds = embeds.to(self.output_proj.weight)
|
| 560 |
-
embeds = self.output_proj(embeds)
|
| 561 |
-
|
| 562 |
-
if lengths is not None and self._use_masking:
|
| 563 |
-
lengths = lengths / self._downsampling_factor()
|
| 564 |
-
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
| 565 |
-
else:
|
| 566 |
-
mask = torch.ones_like(embeds[..., 0])
|
| 567 |
-
embeds = (embeds * mask.unsqueeze(-1))
|
| 568 |
-
return embeds, mask
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
class ChromaStemConditioner(WaveformConditioner):
|
| 572 |
-
"""Chroma conditioner based on stems.
|
| 573 |
-
The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
|
| 574 |
-
the drums and bass often dominate the chroma leading to the chroma features
|
| 575 |
-
not containing information about the melody.
|
| 576 |
-
|
| 577 |
-
Args:
|
| 578 |
-
output_dim (int): Output dimension for the conditioner.
|
| 579 |
-
sample_rate (int): Sample rate for the chroma extractor.
|
| 580 |
-
n_chroma (int): Number of chroma bins for the chroma extractor.
|
| 581 |
-
radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
|
| 582 |
-
duration (int): duration used during training. This is later used for correct padding
|
| 583 |
-
in case we are using chroma as prefix.
|
| 584 |
-
match_len_on_eval (bool, optional): if True then all chromas are padded to the training
|
| 585 |
-
duration. Defaults to False.
|
| 586 |
-
eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
|
| 587 |
-
conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
|
| 588 |
-
Defaults to None.
|
| 589 |
-
n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
|
| 590 |
-
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
| 591 |
-
**kwargs: Additional parameters for the chroma extractor.
|
| 592 |
-
"""
|
| 593 |
-
def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
|
| 594 |
-
duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
|
| 595 |
-
n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
| 596 |
-
device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
| 597 |
-
from demucs import pretrained
|
| 598 |
-
super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
|
| 599 |
-
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
| 600 |
-
self.sample_rate = sample_rate
|
| 601 |
-
self.match_len_on_eval = match_len_on_eval
|
| 602 |
-
if match_len_on_eval:
|
| 603 |
-
self._use_masking = False
|
| 604 |
-
self.duration = duration
|
| 605 |
-
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
| 606 |
-
stem_sources: list = self.demucs.sources # type: ignore
|
| 607 |
-
self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
|
| 608 |
-
self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
|
| 609 |
-
radix2_exp=radix2_exp, **kwargs).to(device)
|
| 610 |
-
self.chroma_len = self._get_chroma_len()
|
| 611 |
-
self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
|
| 612 |
-
self.cache = None
|
| 613 |
-
if cache_path is not None:
|
| 614 |
-
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
| 615 |
-
compute_embed_fn=self._get_full_chroma_for_cache,
|
| 616 |
-
extract_embed_fn=self._extract_chroma_chunk)
|
| 617 |
-
|
| 618 |
-
def _downsampling_factor(self) -> int:
|
| 619 |
-
return self.chroma.winhop
|
| 620 |
-
|
| 621 |
-
def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
|
| 622 |
-
"""Load pre-defined waveforms from a json.
|
| 623 |
-
These waveforms will be used for chroma extraction during evaluation.
|
| 624 |
-
This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
|
| 625 |
-
"""
|
| 626 |
-
if path is None:
|
| 627 |
-
return None
|
| 628 |
-
|
| 629 |
-
logger.info(f"Loading evaluation wavs from {path}")
|
| 630 |
-
from audiocraft.data.audio_dataset import AudioDataset
|
| 631 |
-
dataset: AudioDataset = AudioDataset.from_meta(
|
| 632 |
-
path, segment_duration=self.duration, min_audio_duration=self.duration,
|
| 633 |
-
sample_rate=self.sample_rate, channels=1)
|
| 634 |
-
|
| 635 |
-
if len(dataset) > 0:
|
| 636 |
-
eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
|
| 637 |
-
logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
|
| 638 |
-
return eval_wavs
|
| 639 |
-
else:
|
| 640 |
-
raise ValueError("Could not find evaluation wavs, check lengths of wavs")
|
| 641 |
-
|
| 642 |
-
def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
|
| 643 |
-
self.eval_wavs = eval_wavs
|
| 644 |
-
|
| 645 |
-
def has_eval_wavs(self) -> bool:
|
| 646 |
-
return self.eval_wavs is not None
|
| 647 |
-
|
| 648 |
-
def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
|
| 649 |
-
"""Sample wavs from a predefined list."""
|
| 650 |
-
assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
|
| 651 |
-
total_eval_wavs = len(self.eval_wavs)
|
| 652 |
-
out = self.eval_wavs
|
| 653 |
-
if num_samples > total_eval_wavs:
|
| 654 |
-
out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
|
| 655 |
-
return out[torch.randperm(len(out))][:num_samples]
|
| 656 |
-
|
| 657 |
-
def _get_chroma_len(self) -> int:
|
| 658 |
-
"""Get length of chroma during training."""
|
| 659 |
-
dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
|
| 660 |
-
dummy_chr = self.chroma(dummy_wav)
|
| 661 |
-
return dummy_chr.shape[1]
|
| 662 |
-
|
| 663 |
-
@torch.no_grad()
|
| 664 |
-
def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 665 |
-
"""Get parts of the wav that holds the melody, extracting the main stems from the wav."""
|
| 666 |
-
from demucs.apply import apply_model
|
| 667 |
-
from demucs.audio import convert_audio
|
| 668 |
-
with self.autocast:
|
| 669 |
-
wav = convert_audio(
|
| 670 |
-
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
| 671 |
-
stems = apply_model(self.demucs, wav, device=self.device) # type: ignore
|
| 672 |
-
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
| 673 |
-
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
| 674 |
-
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
| 675 |
-
return mix_wav
|
| 676 |
-
|
| 677 |
-
@torch.no_grad()
|
| 678 |
-
def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
|
| 679 |
-
"""Extract chroma features from the waveform."""
|
| 680 |
-
with self.autocast:
|
| 681 |
-
return self.chroma(wav)
|
| 682 |
-
|
| 683 |
-
@torch.no_grad()
|
| 684 |
-
def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 685 |
-
"""Compute wav embedding, applying stem and chroma extraction."""
|
| 686 |
-
# avoid 0-size tensors when we are working with null conds
|
| 687 |
-
if wav.shape[-1] == 1:
|
| 688 |
-
return self._extract_chroma(wav)
|
| 689 |
-
stems = self._get_stemmed_wav(wav, sample_rate)
|
| 690 |
-
chroma = self._extract_chroma(stems)
|
| 691 |
-
return chroma
|
| 692 |
-
|
| 693 |
-
@torch.no_grad()
|
| 694 |
-
def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
|
| 695 |
-
"""Extract chroma from the whole audio waveform at the given path."""
|
| 696 |
-
wav, sr = audio_read(path)
|
| 697 |
-
wav = wav[None].to(self.device)
|
| 698 |
-
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
| 699 |
-
chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
|
| 700 |
-
return chroma
|
| 701 |
-
|
| 702 |
-
def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
| 703 |
-
"""Extract a chunk of chroma from the full chroma derived from the full waveform."""
|
| 704 |
-
wav_length = x.wav.shape[-1]
|
| 705 |
-
seek_time = x.seek_time[idx]
|
| 706 |
-
assert seek_time is not None, (
|
| 707 |
-
"WavCondition seek_time is required "
|
| 708 |
-
"when extracting chroma chunks from pre-computed chroma.")
|
| 709 |
-
full_chroma = full_chroma.float()
|
| 710 |
-
frame_rate = self.sample_rate / self._downsampling_factor()
|
| 711 |
-
target_length = int(frame_rate * wav_length / self.sample_rate)
|
| 712 |
-
index = int(frame_rate * seek_time)
|
| 713 |
-
out = full_chroma[index: index + target_length]
|
| 714 |
-
out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
|
| 715 |
-
return out.to(self.device)
|
| 716 |
-
|
| 717 |
-
@torch.no_grad()
|
| 718 |
-
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 719 |
-
"""Get the wav embedding from the WavCondition.
|
| 720 |
-
The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
|
| 721 |
-
or will rely on the embedding cache to load the pre-computed embedding if relevant.
|
| 722 |
-
"""
|
| 723 |
-
sampled_wav: tp.Optional[torch.Tensor] = None
|
| 724 |
-
if not self.training and self.eval_wavs is not None:
|
| 725 |
-
warn_once(logger, "Using precomputed evaluation wavs!")
|
| 726 |
-
sampled_wav = self._sample_eval_wavs(len(x.wav))
|
| 727 |
-
|
| 728 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 729 |
-
no_nullified_cond = x.wav.shape[-1] > 1
|
| 730 |
-
if sampled_wav is not None:
|
| 731 |
-
chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
|
| 732 |
-
elif self.cache is not None and no_undefined_paths and no_nullified_cond:
|
| 733 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 734 |
-
chroma = self.cache.get_embed_from_cache(paths, x)
|
| 735 |
-
else:
|
| 736 |
-
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
| 737 |
-
chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
|
| 738 |
-
|
| 739 |
-
if self.match_len_on_eval:
|
| 740 |
-
B, T, C = chroma.shape
|
| 741 |
-
if T > self.chroma_len:
|
| 742 |
-
chroma = chroma[:, :self.chroma_len]
|
| 743 |
-
logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
|
| 744 |
-
elif T < self.chroma_len:
|
| 745 |
-
n_repeat = int(math.ceil(self.chroma_len / T))
|
| 746 |
-
chroma = chroma.repeat(1, n_repeat, 1)
|
| 747 |
-
chroma = chroma[:, :self.chroma_len]
|
| 748 |
-
logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
|
| 749 |
-
|
| 750 |
-
return chroma
|
| 751 |
-
|
| 752 |
-
def tokenize(self, x: WavCondition) -> WavCondition:
|
| 753 |
-
"""Apply WavConditioner tokenization and populate cache if needed."""
|
| 754 |
-
x = super().tokenize(x)
|
| 755 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 756 |
-
if self.cache is not None and no_undefined_paths:
|
| 757 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 758 |
-
self.cache.populate_embed_cache(paths, x)
|
| 759 |
-
return x
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
class FeatureExtractor(WaveformConditioner):
|
| 763 |
-
"""
|
| 764 |
-
Feature Extractor used for the style conditioner of the paper AUDIO CONDITIONING
|
| 765 |
-
FOR MUSIC GENERATION VIA DISCRETE BOTTLENECK FEATURES.
|
| 766 |
-
|
| 767 |
-
Given a waveform, we extract an excerpt of defined length randomly subsampled.
|
| 768 |
-
Then, we feed this excerpt to a feature extractor.
|
| 769 |
-
|
| 770 |
-
Args:
|
| 771 |
-
model_name (str): 'encodec' or 'mert'.
|
| 772 |
-
sample_rate (str): sample rate of the input audio. (32000)
|
| 773 |
-
encodec_checkpoint (str): if encodec is used as a feature extractor, checkpoint
|
| 774 |
-
of the model. ('//pretrained/facebook/encodec_32khz' is the default)
|
| 775 |
-
encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
|
| 776 |
-
quantization streams used in it.
|
| 777 |
-
length (float): length in seconds of the random subsampled excerpt that is used
|
| 778 |
-
for conditioning.
|
| 779 |
-
dim (int): The internal representation dimension.
|
| 780 |
-
output_dim (int): Output dimension for the conditioner.
|
| 781 |
-
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
| 782 |
-
compute_mask (bool): whether to mask the tokens corresponding to the subsampled
|
| 783 |
-
excerpt in the computation of the music language model cross-entropy loss.
|
| 784 |
-
use_middle_of_segment (bool): if True, always take the middle of the input
|
| 785 |
-
instead of a random subsampled excerpt.
|
| 786 |
-
ds_rate_compression (int): downsampling parameter of the compression model used
|
| 787 |
-
for the music language model. (640 for encodec_32khz)
|
| 788 |
-
num_codebooks_lm (int): the number of codebooks used by the music language model.
|
| 789 |
-
"""
|
| 790 |
-
def __init__(
|
| 791 |
-
self, model_name: str,
|
| 792 |
-
sample_rate: int, encodec_checkpoint: str, encodec_n_q: int, length: float,
|
| 793 |
-
dim: int, output_dim: int, device: tp.Union[torch.device, str],
|
| 794 |
-
compute_mask: bool = True,
|
| 795 |
-
use_middle_of_segment: bool = False, ds_rate_compression: int = 640,
|
| 796 |
-
num_codebooks_lm: int = 4
|
| 797 |
-
):
|
| 798 |
-
assert model_name in ['encodec', 'mert']
|
| 799 |
-
if model_name == 'encodec':
|
| 800 |
-
from ..solvers.compression import CompressionSolver
|
| 801 |
-
feat_extractor = CompressionSolver.model_from_checkpoint(encodec_checkpoint, device)
|
| 802 |
-
elif model_name == 'mert':
|
| 803 |
-
from transformers import AutoModel
|
| 804 |
-
feat_extractor = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
|
| 805 |
-
super().__init__(
|
| 806 |
-
dim=dim,
|
| 807 |
-
output_dim=output_dim,
|
| 808 |
-
device=device
|
| 809 |
-
)
|
| 810 |
-
self.sample_rate = sample_rate
|
| 811 |
-
self.compute_mask = compute_mask
|
| 812 |
-
self.feat_extractor: nn.Module
|
| 813 |
-
self.embed: tp.Union[nn.ModuleList, nn.Linear]
|
| 814 |
-
if model_name == 'encodec':
|
| 815 |
-
self.__dict__["feat_extractor"] = feat_extractor.to(device)
|
| 816 |
-
self.encodec_n_q = encodec_n_q
|
| 817 |
-
self.embed = nn.ModuleList([nn.Embedding(feat_extractor.cardinality, dim) for _ in range(encodec_n_q)])
|
| 818 |
-
if model_name == 'mert':
|
| 819 |
-
self.__dict__["feat_extractor"] = feat_extractor.eval().to(device)
|
| 820 |
-
self.embed = nn.Linear(768, dim) # hardcoded
|
| 821 |
-
self.length_subwav = int(length * sample_rate)
|
| 822 |
-
self.ds_rate_compression = ds_rate_compression
|
| 823 |
-
self.model_name = model_name
|
| 824 |
-
self.use_middle_of_segment = use_middle_of_segment
|
| 825 |
-
self.num_codebooks_lm = num_codebooks_lm
|
| 826 |
-
|
| 827 |
-
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 828 |
-
if x.wav.shape[-1] == 1:
|
| 829 |
-
self.temp_mask = None
|
| 830 |
-
return torch.zeros(x.wav.shape[0], 1, self.dim, device=self.device)
|
| 831 |
-
else:
|
| 832 |
-
with torch.no_grad():
|
| 833 |
-
if self.use_middle_of_segment:
|
| 834 |
-
start = int((x.wav.shape[-1] - self.length_subwav) / 2)
|
| 835 |
-
wav = x.wav[:, :, start:start+self.length_subwav]
|
| 836 |
-
else:
|
| 837 |
-
start = random.randint(0, x.wav.shape[-1] - self.length_subwav)
|
| 838 |
-
wav = x.wav[:, :, start:start+self.length_subwav]
|
| 839 |
-
if self.compute_mask:
|
| 840 |
-
self.temp_mask = self._get_mask_wav(x, start)
|
| 841 |
-
if self.model_name == 'encodec':
|
| 842 |
-
tokens = self.feat_extractor.encode(wav)[0] # type: ignore
|
| 843 |
-
elif self.model_name == 'mert':
|
| 844 |
-
wav = convert_audio(wav, from_rate=x.sample_rate[0], to_rate=24000, to_channels=1)
|
| 845 |
-
embeds = self.feat_extractor(wav.squeeze(-2)).last_hidden_state
|
| 846 |
-
if self.model_name == 'encodec':
|
| 847 |
-
tokens = tokens[:, :self.encodec_n_q]
|
| 848 |
-
embeds = sum([self.embed[k](tokens[:, k]) for k in range(self.encodec_n_q)]) # type: ignore
|
| 849 |
-
else:
|
| 850 |
-
embeds = self.embed(embeds)
|
| 851 |
-
|
| 852 |
-
return embeds # [B, T, dim]
|
| 853 |
-
|
| 854 |
-
def _downsampling_factor(self):
|
| 855 |
-
if self.model_name == 'encodec':
|
| 856 |
-
return self.sample_rate / self.feat_extractor.frame_rate
|
| 857 |
-
elif self.model_name == 'mert':
|
| 858 |
-
return self.sample_rate / 75
|
| 859 |
-
|
| 860 |
-
def _get_mask_wav(self, x: WavCondition, start: int) -> tp.Union[torch.Tensor, None]:
|
| 861 |
-
if x.wav.shape[-1] == 1:
|
| 862 |
-
return None
|
| 863 |
-
total_length = int(x.wav.shape[-1] / self.ds_rate_compression)
|
| 864 |
-
mask_length = int(self.length_subwav / self.ds_rate_compression)
|
| 865 |
-
start = int(start / self.ds_rate_compression)
|
| 866 |
-
mask = torch.ones(x.wav.shape[0], self.num_codebooks_lm,
|
| 867 |
-
total_length, device=self.device, dtype=torch.bool)
|
| 868 |
-
mask[:, :, start:start+mask_length] = 0
|
| 869 |
-
return mask
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
class StyleConditioner(FeatureExtractor):
|
| 873 |
-
"""Conditioner from the paper AUDIO CONDITIONING FOR MUSIC GENERATION VIA
|
| 874 |
-
DISCRETE BOTTLENECK FEATURES.
|
| 875 |
-
Given an audio input, it is passed through a Feature Extractor and a
|
| 876 |
-
transformer encoder. Then it is quantized through RVQ.
|
| 877 |
-
|
| 878 |
-
Args:
|
| 879 |
-
transformer_scale (str): size of the transformer. See in the __init__ to have more infos.
|
| 880 |
-
ds_factor (int): the downsampling factor applied to the representation after quantization.
|
| 881 |
-
encodec_n_q (int): if encodec is used as a feature extractor it sets the number of
|
| 882 |
-
quantization streams used in it.
|
| 883 |
-
n_q_out (int): the number of quantization streams used for the RVQ. If increased, there
|
| 884 |
-
is more information passing as a conditioning.
|
| 885 |
-
eval_q (int): the number of quantization streams used for the RVQ at evaluation time.
|
| 886 |
-
q_dropout (bool): if True, at training time, a random number of stream is sampled
|
| 887 |
-
at each step in the interval [1, n_q_out].
|
| 888 |
-
bins (int): the codebook size used for each quantization stream.
|
| 889 |
-
varying_lengths (List[float]): list of the min and max duration in seconds for the
|
| 890 |
-
randomly subsampled excerpt at training time. For each step a length is sampled
|
| 891 |
-
in this interval.
|
| 892 |
-
batch_norm (bool): use of batch normalization after the transformer. Stabilizes the
|
| 893 |
-
training.
|
| 894 |
-
rvq_threshold_ema_dead_code (float): threshold for dropping dead codes in the
|
| 895 |
-
RVQ.
|
| 896 |
-
"""
|
| 897 |
-
def __init__(self, transformer_scale: str = 'default', ds_factor: int = 15, encodec_n_q: int = 4,
|
| 898 |
-
n_q_out: int = 6, eval_q: int = 3, q_dropout: bool = True, bins: int = 1024,
|
| 899 |
-
varying_lengths: tp.List[float] = [1.5, 4.5],
|
| 900 |
-
batch_norm: bool = True, rvq_threshold_ema_dead_code: float = 0.1,
|
| 901 |
-
**kwargs):
|
| 902 |
-
tr_args: tp.Dict[str, tp.Any]
|
| 903 |
-
if transformer_scale == 'xsmall':
|
| 904 |
-
tr_args = {'d_model': 256, 'num_heads': 8, 'num_layers': 4}
|
| 905 |
-
elif transformer_scale == 'large':
|
| 906 |
-
tr_args = {'d_model': 1024, 'num_heads': 16, 'num_layers': 24}
|
| 907 |
-
elif transformer_scale == 'default':
|
| 908 |
-
tr_args = {'d_model': 512, 'num_heads': 8, 'num_layers': 8}
|
| 909 |
-
elif transformer_scale == 'none':
|
| 910 |
-
tr_args = {'d_model': 512}
|
| 911 |
-
tr_args.update({
|
| 912 |
-
'memory_efficient': True, 'activation': 'gelu',
|
| 913 |
-
'norm_first': True, 'causal': False, 'layer_scale': None,
|
| 914 |
-
'bias_ff': False, 'bias_attn': False,
|
| 915 |
-
})
|
| 916 |
-
dim = tr_args['d_model']
|
| 917 |
-
super().__init__(dim=dim, encodec_n_q=encodec_n_q, **kwargs)
|
| 918 |
-
|
| 919 |
-
self.ds_factor = ds_factor
|
| 920 |
-
if transformer_scale == 'none':
|
| 921 |
-
self.transformer = None
|
| 922 |
-
else:
|
| 923 |
-
self.transformer = StreamingTransformer(dim_feedforward=int(4 * dim), **tr_args)
|
| 924 |
-
self.n_q_out = n_q_out
|
| 925 |
-
self.eval_q = eval_q
|
| 926 |
-
self.rvq = None
|
| 927 |
-
if n_q_out > 0:
|
| 928 |
-
self.rvq = ResidualVectorQuantizer(dim, n_q=n_q_out, q_dropout=q_dropout, bins=bins,
|
| 929 |
-
threshold_ema_dead_code=rvq_threshold_ema_dead_code)
|
| 930 |
-
self.autocast = TorchAutocast(enabled=self.device != 'cpu', device_type=self.device, dtype=torch.float32)
|
| 931 |
-
self.varying_lengths = varying_lengths
|
| 932 |
-
self.batch_norm = None
|
| 933 |
-
if batch_norm:
|
| 934 |
-
self.batch_norm = nn.BatchNorm1d(dim, affine=False)
|
| 935 |
-
self.mask = None
|
| 936 |
-
|
| 937 |
-
def _get_wav_embedding(self, wav: WavCondition) -> torch.Tensor:
|
| 938 |
-
with self.autocast:
|
| 939 |
-
# Sample the length of the excerpts
|
| 940 |
-
if self.varying_lengths and self.training:
|
| 941 |
-
assert len(self.varying_lengths) == 2
|
| 942 |
-
length = random.uniform(self.varying_lengths[0], self.varying_lengths[1])
|
| 943 |
-
self.length_subwav = int(length * self.sample_rate)
|
| 944 |
-
z1 = super()._get_wav_embedding(wav)
|
| 945 |
-
if self.compute_mask:
|
| 946 |
-
self.mask = self.temp_mask # type: ignore
|
| 947 |
-
self.temp_mask = None
|
| 948 |
-
|
| 949 |
-
if self.transformer is not None:
|
| 950 |
-
out1 = self.transformer(z1)
|
| 951 |
-
else:
|
| 952 |
-
out1 = z1
|
| 953 |
-
if self.batch_norm:
|
| 954 |
-
out1 = self.batch_norm(out1.transpose(1, 2)).transpose(1, 2)
|
| 955 |
-
# Apply quantization
|
| 956 |
-
if self.rvq:
|
| 957 |
-
if self.training:
|
| 958 |
-
self.rvq.set_num_codebooks(self.n_q_out)
|
| 959 |
-
else:
|
| 960 |
-
self.rvq.set_num_codebooks(self.eval_q)
|
| 961 |
-
out1 = self.rvq(out1.transpose(1, 2), frame_rate=1.)
|
| 962 |
-
if self.training:
|
| 963 |
-
flashy.distrib.average_tensors(self.rvq.buffers())
|
| 964 |
-
out1 = out1.x.transpose(1, 2)
|
| 965 |
-
# Apply fix downsample
|
| 966 |
-
out1 = out1[:, ::self.ds_factor]
|
| 967 |
-
|
| 968 |
-
return out1
|
| 969 |
-
|
| 970 |
-
def set_params(self, eval_q: int = 3,
|
| 971 |
-
excerpt_length: float = 3.0,
|
| 972 |
-
ds_factor: tp.Optional[int] = None, encodec_n_q: tp.Optional[int] = None):
|
| 973 |
-
"""Modify the parameters of the SSL or introduce new parameters to add noise to
|
| 974 |
-
the conditioning or to downsample it
|
| 975 |
-
|
| 976 |
-
Args:
|
| 977 |
-
eval_q (int): number of codebooks used when evaluating the model
|
| 978 |
-
excerpt_length (float): the length of the excerpts used to condition the model
|
| 979 |
-
"""
|
| 980 |
-
self.eval_q = eval_q
|
| 981 |
-
self.length_subwav = int(excerpt_length * self.sample_rate)
|
| 982 |
-
if ds_factor is not None:
|
| 983 |
-
self.ds_factor = ds_factor
|
| 984 |
-
if encodec_n_q is not None:
|
| 985 |
-
self.encodec_n_q = encodec_n_q
|
| 986 |
-
|
| 987 |
-
def _downsampling_factor(self):
|
| 988 |
-
df = super()._downsampling_factor()
|
| 989 |
-
return df * self.ds_factor
|
| 990 |
-
|
| 991 |
-
def forward(self, x: WavCondition) -> ConditionType:
|
| 992 |
-
wav, lengths, *_ = x
|
| 993 |
-
|
| 994 |
-
embeds = self._get_wav_embedding(x)
|
| 995 |
-
embeds = embeds.to(self.output_proj.weight)
|
| 996 |
-
embeds = self.output_proj(embeds)
|
| 997 |
-
|
| 998 |
-
lengths = lengths / self._downsampling_factor()
|
| 999 |
-
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
| 1000 |
-
|
| 1001 |
-
embeds = (embeds * mask.unsqueeze(2).to(self.device))
|
| 1002 |
-
|
| 1003 |
-
return embeds, mask
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
class JointEmbeddingConditioner(BaseConditioner):
|
| 1007 |
-
"""Joint embedding conditioning supporting both audio or text conditioning.
|
| 1008 |
-
|
| 1009 |
-
Args:
|
| 1010 |
-
dim (int): Dimension.
|
| 1011 |
-
output_dim (int): Output dimension.
|
| 1012 |
-
device (str): Device.
|
| 1013 |
-
attribute (str): Attribute used by the conditioner.
|
| 1014 |
-
autocast_dtype (str): Autocast for the conditioner.
|
| 1015 |
-
quantize (bool): Whether to quantize the CLAP embedding.
|
| 1016 |
-
n_q (int): Number of residual quantizers (used if quantize is true).
|
| 1017 |
-
bins (int): Quantizers' codebooks size (used if quantize is true).
|
| 1018 |
-
kwargs: Additional parameters for residual vector quantizer.
|
| 1019 |
-
"""
|
| 1020 |
-
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
| 1021 |
-
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
|
| 1022 |
-
n_q: int = 12, bins: int = 1024, **kwargs):
|
| 1023 |
-
super().__init__(dim=dim, output_dim=output_dim)
|
| 1024 |
-
self.device = device
|
| 1025 |
-
self.attribute = attribute
|
| 1026 |
-
if autocast_dtype is None or device == 'cpu':
|
| 1027 |
-
self.autocast = TorchAutocast(enabled=False)
|
| 1028 |
-
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
| 1029 |
-
else:
|
| 1030 |
-
dtype = getattr(torch, autocast_dtype)
|
| 1031 |
-
assert isinstance(dtype, torch.dtype)
|
| 1032 |
-
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
|
| 1033 |
-
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
| 1034 |
-
# residual vector quantizer to discretize the conditioned embedding
|
| 1035 |
-
self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
|
| 1036 |
-
if quantize:
|
| 1037 |
-
self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
| 1038 |
-
|
| 1039 |
-
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 1040 |
-
"""Get joint embedding in latent space from the inputs.
|
| 1041 |
-
|
| 1042 |
-
Returns:
|
| 1043 |
-
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
| 1044 |
-
and corresponding empty indexes.
|
| 1045 |
-
"""
|
| 1046 |
-
raise NotImplementedError()
|
| 1047 |
-
|
| 1048 |
-
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
| 1049 |
-
with self.autocast:
|
| 1050 |
-
embed, empty_idx = self._get_embed(x)
|
| 1051 |
-
if self.quantizer is not None:
|
| 1052 |
-
embed = embed.view(-1, self.dim, 1)
|
| 1053 |
-
q_res = self.quantizer(embed, frame_rate=1)
|
| 1054 |
-
out_embed = q_res.x.view(-1, self.dim)
|
| 1055 |
-
else:
|
| 1056 |
-
out_embed = embed
|
| 1057 |
-
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
| 1058 |
-
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
| 1059 |
-
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
| 1060 |
-
out_embed = (out_embed * mask.unsqueeze(-1))
|
| 1061 |
-
return out_embed, mask
|
| 1062 |
-
|
| 1063 |
-
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
| 1064 |
-
return x
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
| 1068 |
-
"""Joint Embedding conditioner based on pre-trained CLAP model.
|
| 1069 |
-
|
| 1070 |
-
This CLAP-based conditioner supports a caching mechanism
|
| 1071 |
-
over the computed embeddings for faster training.
|
| 1072 |
-
|
| 1073 |
-
Args:
|
| 1074 |
-
dim (int): Dimension.
|
| 1075 |
-
output_dim (int): Output dimension.
|
| 1076 |
-
device (str): Device.
|
| 1077 |
-
attribute (str): Attribute used by the conditioner.
|
| 1078 |
-
quantize (bool): Whether to quantize the CLAP embedding.
|
| 1079 |
-
n_q (int): Number of residual quantizers (used if quantize is true).
|
| 1080 |
-
bins (int): Quantizers' codebooks size (used if quantize is true).
|
| 1081 |
-
checkpoint (str): Path to CLAP checkpoint.
|
| 1082 |
-
model_arch (str): CLAP model architecture.
|
| 1083 |
-
enable_fusion (bool): Enable fusion for CLAP model.
|
| 1084 |
-
sample_rate (int): Sample rate used by CLAP model.
|
| 1085 |
-
max_audio_length (float): Maximum audio length for CLAP model.
|
| 1086 |
-
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
|
| 1087 |
-
normalize (bool): Whether to normalize the CLAP embedding.
|
| 1088 |
-
text_p (float): Probability of using text representation instead of audio at train time.
|
| 1089 |
-
batch_size (Optional[int]): Batch size for CLAP embedding computation.
|
| 1090 |
-
autocast_dtype (str): Autocast for the conditioner.
|
| 1091 |
-
cache_path (Optional[str]): Path for pre-computed embeddings caching.
|
| 1092 |
-
kwargs: Additional parameters for residual vector quantizer.
|
| 1093 |
-
"""
|
| 1094 |
-
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
| 1095 |
-
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
|
| 1096 |
-
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
|
| 1097 |
-
normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
|
| 1098 |
-
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
|
| 1099 |
-
try:
|
| 1100 |
-
import laion_clap # type: ignore
|
| 1101 |
-
except ImportError:
|
| 1102 |
-
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
|
| 1103 |
-
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
|
| 1104 |
-
"Please retrain all models.")
|
| 1105 |
-
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
|
| 1106 |
-
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
| 1107 |
-
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
| 1108 |
-
load_clap_state_dict(clap_model, checkpoint)
|
| 1109 |
-
clap_model.eval()
|
| 1110 |
-
clap_model.to(device)
|
| 1111 |
-
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
|
| 1112 |
-
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
|
| 1113 |
-
**kwargs)
|
| 1114 |
-
self.checkpoint = checkpoint
|
| 1115 |
-
self.enable_fusion = enable_fusion
|
| 1116 |
-
self.model_arch = model_arch
|
| 1117 |
-
self.clap: laion_clap.CLAP_Module
|
| 1118 |
-
self.clap_tokenize: RobertaTokenizer
|
| 1119 |
-
self.clap_sample_rate = sample_rate
|
| 1120 |
-
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
|
| 1121 |
-
self.clap_stride = int(self.clap_sample_rate * audio_stride)
|
| 1122 |
-
self.batch_size = batch_size or 1
|
| 1123 |
-
self.normalize = normalize
|
| 1124 |
-
self.text_p = text_p
|
| 1125 |
-
self.__dict__['clap_tokenize'] = clap_tokenize
|
| 1126 |
-
self.__dict__['clap'] = clap_model
|
| 1127 |
-
self.wav_cache, self.text_cache = None, None
|
| 1128 |
-
if cache_path is not None:
|
| 1129 |
-
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
| 1130 |
-
compute_embed_fn=self._get_wav_embedding_for_cache,
|
| 1131 |
-
extract_embed_fn=self._extract_wav_embedding_chunk)
|
| 1132 |
-
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
|
| 1133 |
-
compute_embed_fn=self._get_text_embedding_for_cache)
|
| 1134 |
-
|
| 1135 |
-
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
| 1136 |
-
# we use the default params from CLAP module here as well
|
| 1137 |
-
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
| 1138 |
-
|
| 1139 |
-
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
|
| 1140 |
-
"""Compute text embedding from CLAP model on a given a batch of text.
|
| 1141 |
-
|
| 1142 |
-
Args:
|
| 1143 |
-
text (list[str]): List of text for the batch, with B items.
|
| 1144 |
-
Returns:
|
| 1145 |
-
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
|
| 1146 |
-
"""
|
| 1147 |
-
with torch.no_grad():
|
| 1148 |
-
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
| 1149 |
-
return embed.view(embed.size(0), 1, embed.size(-1))
|
| 1150 |
-
|
| 1151 |
-
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
|
| 1152 |
-
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 1153 |
-
"""Get text embedding function for the cache."""
|
| 1154 |
-
text = x.text[idx]
|
| 1155 |
-
text = text if text is not None else ""
|
| 1156 |
-
return self._compute_text_embedding([text])[0]
|
| 1157 |
-
|
| 1158 |
-
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
|
| 1159 |
-
"""Preprocess wav to expected format by CLAP model.
|
| 1160 |
-
|
| 1161 |
-
Args:
|
| 1162 |
-
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
| 1163 |
-
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
| 1164 |
-
sample_rates (list[int]): Sample rates for each sample in the batch
|
| 1165 |
-
Returns:
|
| 1166 |
-
torch.Tensor: Audio wav of shape [B, T].
|
| 1167 |
-
"""
|
| 1168 |
-
assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
|
| 1169 |
-
if sample_rates is not None:
|
| 1170 |
-
_wav = []
|
| 1171 |
-
for i, audio in enumerate(wav):
|
| 1172 |
-
sr = sample_rates[i]
|
| 1173 |
-
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
|
| 1174 |
-
_wav.append(audio)
|
| 1175 |
-
wav = torch.stack(_wav, dim=0)
|
| 1176 |
-
wav = wav.mean(dim=1)
|
| 1177 |
-
return wav
|
| 1178 |
-
|
| 1179 |
-
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
|
| 1180 |
-
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
|
| 1181 |
-
"""Compute audio wave embedding from CLAP model.
|
| 1182 |
-
|
| 1183 |
-
Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
|
| 1184 |
-
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
|
| 1185 |
-
average the resulting embeddings.
|
| 1186 |
-
|
| 1187 |
-
Args:
|
| 1188 |
-
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
| 1189 |
-
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
| 1190 |
-
sample_rates (list[int]): Sample rates for each sample in the batch.
|
| 1191 |
-
reduce_mean (bool): Whether to get the average tensor.
|
| 1192 |
-
Returns:
|
| 1193 |
-
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
|
| 1194 |
-
"""
|
| 1195 |
-
with torch.no_grad():
|
| 1196 |
-
wav = self._preprocess_wav(wav, length, sample_rates)
|
| 1197 |
-
B, T = wav.shape
|
| 1198 |
-
if T >= self.clap_max_frames:
|
| 1199 |
-
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
|
| 1200 |
-
else:
|
| 1201 |
-
wav = wav.view(-1, 1, T) # [B, F, T] with F=1
|
| 1202 |
-
wav = einops.rearrange(wav, 'b f t -> (b f) t')
|
| 1203 |
-
embed_list = []
|
| 1204 |
-
for i in range(0, wav.size(0), self.batch_size):
|
| 1205 |
-
_wav = wav[i:i+self.batch_size, ...]
|
| 1206 |
-
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
|
| 1207 |
-
embed_list.append(_embed)
|
| 1208 |
-
embed = torch.cat(embed_list, dim=0)
|
| 1209 |
-
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
|
| 1210 |
-
if reduce_mean:
|
| 1211 |
-
embed = embed.mean(dim=1, keepdim=True)
|
| 1212 |
-
return embed # [B, F, D] with F=1 if reduce_mean is True
|
| 1213 |
-
|
| 1214 |
-
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
|
| 1215 |
-
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 1216 |
-
"""Compute audio wave embedding for the cache.
|
| 1217 |
-
The embedding is computed on a given audio read from file.
|
| 1218 |
-
|
| 1219 |
-
Args:
|
| 1220 |
-
path (str or Path): Path to the full audio file.
|
| 1221 |
-
Returns:
|
| 1222 |
-
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
|
| 1223 |
-
"""
|
| 1224 |
-
wav, sr = audio_read(path) # [C, T]
|
| 1225 |
-
wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
|
| 1226 |
-
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
|
| 1227 |
-
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
|
| 1228 |
-
return embed.squeeze(0) # [F, D]
|
| 1229 |
-
|
| 1230 |
-
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 1231 |
-
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
|
| 1232 |
-
|
| 1233 |
-
Args:
|
| 1234 |
-
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
|
| 1235 |
-
x (JointEmbedCondition): Joint embedding condition for the full batch.
|
| 1236 |
-
idx (int): Index considered for the given embedding to extract.
|
| 1237 |
-
Returns:
|
| 1238 |
-
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
|
| 1239 |
-
"""
|
| 1240 |
-
sample_rate = x.sample_rate[idx]
|
| 1241 |
-
seek_time = x.seek_time[idx]
|
| 1242 |
-
seek_time = 0. if seek_time is None else seek_time
|
| 1243 |
-
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
|
| 1244 |
-
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
|
| 1245 |
-
start_offset = int(seek_time * sample_rate // clap_stride)
|
| 1246 |
-
end_offset = int(end_seek_time * sample_rate // clap_stride)
|
| 1247 |
-
wav_embed = full_embed[start_offset:end_offset, ...]
|
| 1248 |
-
wav_embed = wav_embed.mean(dim=0, keepdim=True)
|
| 1249 |
-
return wav_embed.to(self.device) # [F, D]
|
| 1250 |
-
|
| 1251 |
-
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
| 1252 |
-
"""Get CLAP embedding from a batch of text descriptions."""
|
| 1253 |
-
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
| 1254 |
-
if self.text_cache is not None and no_nullified_cond:
|
| 1255 |
-
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
|
| 1256 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 1257 |
-
embed = self.text_cache.get_embed_from_cache(paths, x)
|
| 1258 |
-
else:
|
| 1259 |
-
text = [xi if xi is not None else "" for xi in x.text]
|
| 1260 |
-
embed = self._compute_text_embedding(text)
|
| 1261 |
-
if self.normalize:
|
| 1262 |
-
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
| 1263 |
-
return embed
|
| 1264 |
-
|
| 1265 |
-
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
| 1266 |
-
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
|
| 1267 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 1268 |
-
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
| 1269 |
-
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
|
| 1270 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 1271 |
-
embed = self.wav_cache.get_embed_from_cache(paths, x)
|
| 1272 |
-
else:
|
| 1273 |
-
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
|
| 1274 |
-
if self.normalize:
|
| 1275 |
-
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
| 1276 |
-
return embed
|
| 1277 |
-
|
| 1278 |
-
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
| 1279 |
-
# Trying to limit as much as possible sync points when the cache is warm.
|
| 1280 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 1281 |
-
if self.wav_cache is not None and no_undefined_paths:
|
| 1282 |
-
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
| 1283 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 1284 |
-
self.wav_cache.populate_embed_cache(paths, x)
|
| 1285 |
-
if self.text_cache is not None and no_undefined_paths:
|
| 1286 |
-
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
| 1287 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 1288 |
-
self.text_cache.populate_embed_cache(paths, x)
|
| 1289 |
-
return x
|
| 1290 |
-
|
| 1291 |
-
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 1292 |
-
"""Extract shared latent representation from either the wav or the text using CLAP."""
|
| 1293 |
-
# decide whether to use text embedding at train time or not
|
| 1294 |
-
use_text_embed = random.random() < self.text_p
|
| 1295 |
-
if self.training and not use_text_embed:
|
| 1296 |
-
embed = self._get_wav_embedding(x)
|
| 1297 |
-
empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
|
| 1298 |
-
else:
|
| 1299 |
-
embed = self._get_text_embedding(x)
|
| 1300 |
-
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
|
| 1301 |
-
return embed, empty_idx
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
def dropout_symbolic_conditions(sample: ConditioningAttributes,
|
| 1305 |
-
condition: str, null_chord_idx: int = 194) -> ConditioningAttributes:
|
| 1306 |
-
"""
|
| 1307 |
-
Applies dropout to symbolic conditions within the sample based on the specified condition by setting the condition
|
| 1308 |
-
value to a null index.
|
| 1309 |
-
Args:
|
| 1310 |
-
sample (ConditioningAttributes): The sample containing symbolic attributes to potentially dropout.
|
| 1311 |
-
condition (str): The specific condition within the symbolic attributes to apply dropout.
|
| 1312 |
-
null_chord_idx (int, optional): The index used to represent a null chord. Defaults to 194.
|
| 1313 |
-
Returns:
|
| 1314 |
-
ConditioningAttributes: The modified sample with dropout applied to the specified condition.
|
| 1315 |
-
Raises:
|
| 1316 |
-
ValueError: If the specified condition is not present in the sample's symbolic attributes.
|
| 1317 |
-
"""
|
| 1318 |
-
if sample.symbolic == {} or sample.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1: # type: ignore
|
| 1319 |
-
# nothing to drop
|
| 1320 |
-
return sample
|
| 1321 |
-
|
| 1322 |
-
if condition not in getattr(sample, 'symbolic'):
|
| 1323 |
-
raise ValueError(
|
| 1324 |
-
"dropout_symbolic_condition received an unexpected condition!"
|
| 1325 |
-
f" expected {sample.symbolic.keys()}"
|
| 1326 |
-
f" but got '{condition}'!"
|
| 1327 |
-
)
|
| 1328 |
-
|
| 1329 |
-
if condition == JascoCondConst.CRD.value:
|
| 1330 |
-
sample.symbolic[condition] = nullify_chords(sample.symbolic[condition], null_chord_idx=null_chord_idx)
|
| 1331 |
-
elif condition == JascoCondConst.MLD.value:
|
| 1332 |
-
sample.symbolic[condition] = nullify_melody(sample.symbolic[condition])
|
| 1333 |
-
|
| 1334 |
-
return sample
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
def dropout_condition(sample: ConditioningAttributes,
|
| 1338 |
-
condition_type: str, condition: str,
|
| 1339 |
-
**kwargs) -> ConditioningAttributes:
|
| 1340 |
-
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
| 1341 |
-
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
| 1342 |
-
If the condition is of any other type, set its value to None.
|
| 1343 |
-
Works in-place.
|
| 1344 |
-
"""
|
| 1345 |
-
if condition_type not in ['text', 'wav', 'joint_embed', 'symbolic']:
|
| 1346 |
-
raise ValueError(
|
| 1347 |
-
"dropout_condition got an unexpected condition type!"
|
| 1348 |
-
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
| 1349 |
-
)
|
| 1350 |
-
|
| 1351 |
-
if condition not in getattr(sample, condition_type):
|
| 1352 |
-
raise ValueError(
|
| 1353 |
-
"dropout_condition received an unexpected condition!"
|
| 1354 |
-
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
| 1355 |
-
f" but got '{condition}' of type '{condition_type}'!"
|
| 1356 |
-
)
|
| 1357 |
-
|
| 1358 |
-
if condition_type == 'wav':
|
| 1359 |
-
wav_cond = sample.wav[condition]
|
| 1360 |
-
sample.wav[condition] = nullify_wav(wav_cond)
|
| 1361 |
-
elif condition_type == 'joint_embed':
|
| 1362 |
-
embed = sample.joint_embed[condition]
|
| 1363 |
-
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
| 1364 |
-
elif condition_type == 'symbolic':
|
| 1365 |
-
sample = dropout_symbolic_conditions(sample=sample, condition=condition, **kwargs)
|
| 1366 |
-
else:
|
| 1367 |
-
sample.text[condition] = None
|
| 1368 |
-
|
| 1369 |
-
return sample
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
class DropoutModule(nn.Module):
|
| 1373 |
-
"""Base module for all dropout modules."""
|
| 1374 |
-
def __init__(self, seed: int = 1234):
|
| 1375 |
-
super().__init__()
|
| 1376 |
-
self.rng = torch.Generator()
|
| 1377 |
-
self.rng.manual_seed(seed)
|
| 1378 |
-
|
| 1379 |
-
|
| 1380 |
-
class AttributeDropout(DropoutModule):
|
| 1381 |
-
"""Dropout with a given probability per attribute.
|
| 1382 |
-
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
| 1383 |
-
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
| 1384 |
-
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
| 1385 |
-
must also be dropped.
|
| 1386 |
-
|
| 1387 |
-
Args:
|
| 1388 |
-
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
| 1389 |
-
...
|
| 1390 |
-
"genre": 0.1,
|
| 1391 |
-
"artist": 0.5,
|
| 1392 |
-
"wav": 0.25,
|
| 1393 |
-
...
|
| 1394 |
-
active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
|
| 1395 |
-
seed (int, optional): Random seed.
|
| 1396 |
-
"""
|
| 1397 |
-
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
|
| 1398 |
-
super().__init__(seed=seed)
|
| 1399 |
-
self.active_on_eval = active_on_eval
|
| 1400 |
-
# construct dict that return the values from p otherwise 0
|
| 1401 |
-
self.p = {}
|
| 1402 |
-
for condition_type, probs in p.items():
|
| 1403 |
-
self.p[condition_type] = defaultdict(lambda: 0, probs)
|
| 1404 |
-
|
| 1405 |
-
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
| 1406 |
-
"""
|
| 1407 |
-
Args:
|
| 1408 |
-
samples (list[ConditioningAttributes]): List of conditions.
|
| 1409 |
-
Returns:
|
| 1410 |
-
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
| 1411 |
-
"""
|
| 1412 |
-
if not self.training and not self.active_on_eval:
|
| 1413 |
-
return samples
|
| 1414 |
-
|
| 1415 |
-
samples = deepcopy(samples)
|
| 1416 |
-
for condition_type, ps in self.p.items(): # for condition types [text, wav, symbolic]
|
| 1417 |
-
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
| 1418 |
-
if torch.rand(1, generator=self.rng).item() < p:
|
| 1419 |
-
for sample in samples:
|
| 1420 |
-
dropout_condition(sample, condition_type, condition)
|
| 1421 |
-
return samples
|
| 1422 |
-
|
| 1423 |
-
def __repr__(self):
|
| 1424 |
-
return f"AttributeDropout({dict(self.p)})"
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
class ClassifierFreeGuidanceDropout(DropoutModule):
|
| 1428 |
-
"""Classifier Free Guidance dropout.
|
| 1429 |
-
All attributes are dropped with the same probability.
|
| 1430 |
-
|
| 1431 |
-
Args:
|
| 1432 |
-
p (float): Probability to apply condition dropout during training.
|
| 1433 |
-
seed (int): Random seed.
|
| 1434 |
-
"""
|
| 1435 |
-
def __init__(self, p: float, seed: int = 1234):
|
| 1436 |
-
super().__init__(seed=seed)
|
| 1437 |
-
self.p = p
|
| 1438 |
-
|
| 1439 |
-
def forward(self, samples: tp.List[ConditioningAttributes],
|
| 1440 |
-
cond_types: tp.List[str] = ["wav", "text"],
|
| 1441 |
-
**kwargs) -> tp.List[ConditioningAttributes]:
|
| 1442 |
-
"""
|
| 1443 |
-
Args:
|
| 1444 |
-
samples (list[ConditioningAttributes]): List of conditions.
|
| 1445 |
-
Returns:
|
| 1446 |
-
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
| 1447 |
-
"""
|
| 1448 |
-
if not self.training:
|
| 1449 |
-
return samples
|
| 1450 |
-
|
| 1451 |
-
# decide on which attributes to drop in a batched fashion
|
| 1452 |
-
drop = torch.rand(1, generator=self.rng).item() < self.p
|
| 1453 |
-
if not drop:
|
| 1454 |
-
return samples
|
| 1455 |
-
|
| 1456 |
-
# nullify conditions of all attributes
|
| 1457 |
-
samples = deepcopy(samples)
|
| 1458 |
-
for condition_type in cond_types:
|
| 1459 |
-
for sample in samples:
|
| 1460 |
-
for condition in sample.attributes[condition_type]:
|
| 1461 |
-
dropout_condition(sample, condition_type, condition,
|
| 1462 |
-
**kwargs)
|
| 1463 |
-
return samples
|
| 1464 |
-
|
| 1465 |
-
def __repr__(self):
|
| 1466 |
-
return f"ClassifierFreeGuidanceDropout(p={self.p})"
|
| 1467 |
-
|
| 1468 |
-
|
| 1469 |
-
class ConditioningProvider(nn.Module):
|
| 1470 |
-
"""Prepare and provide conditions given all the supported conditioners.
|
| 1471 |
-
|
| 1472 |
-
Args:
|
| 1473 |
-
conditioners (dict): Dictionary of conditioners.
|
| 1474 |
-
device (torch.device or str, optional): Device for conditioners and output condition types.
|
| 1475 |
-
"""
|
| 1476 |
-
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
|
| 1477 |
-
super().__init__()
|
| 1478 |
-
self.device = device
|
| 1479 |
-
self.conditioners = nn.ModuleDict(conditioners)
|
| 1480 |
-
|
| 1481 |
-
@property
|
| 1482 |
-
def joint_embed_conditions(self):
|
| 1483 |
-
return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
|
| 1484 |
-
|
| 1485 |
-
@property
|
| 1486 |
-
def has_joint_embed_conditions(self):
|
| 1487 |
-
return len(self.joint_embed_conditions) > 0
|
| 1488 |
-
|
| 1489 |
-
@property
|
| 1490 |
-
def text_conditions(self):
|
| 1491 |
-
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
| 1492 |
-
|
| 1493 |
-
@property
|
| 1494 |
-
def wav_conditions(self):
|
| 1495 |
-
return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
|
| 1496 |
-
|
| 1497 |
-
@property
|
| 1498 |
-
def has_wav_condition(self):
|
| 1499 |
-
return len(self.wav_conditions) > 0
|
| 1500 |
-
|
| 1501 |
-
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
| 1502 |
-
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
|
| 1503 |
-
This should be called before starting any real GPU work to avoid synchronization points.
|
| 1504 |
-
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
| 1505 |
-
|
| 1506 |
-
Args:
|
| 1507 |
-
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
| 1508 |
-
text and wav conditions.
|
| 1509 |
-
"""
|
| 1510 |
-
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
| 1511 |
-
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
| 1512 |
-
f" but types were {set([type(x) for x in inputs])}"
|
| 1513 |
-
)
|
| 1514 |
-
|
| 1515 |
-
output = {}
|
| 1516 |
-
text = self._collate_text(inputs)
|
| 1517 |
-
wavs = self._collate_wavs(inputs)
|
| 1518 |
-
joint_embeds = self._collate_joint_embeds(inputs)
|
| 1519 |
-
|
| 1520 |
-
assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
|
| 1521 |
-
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
| 1522 |
-
f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
|
| 1523 |
-
)
|
| 1524 |
-
|
| 1525 |
-
for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
|
| 1526 |
-
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
| 1527 |
-
return output
|
| 1528 |
-
|
| 1529 |
-
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
|
| 1530 |
-
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
| 1531 |
-
The output is for example:
|
| 1532 |
-
{
|
| 1533 |
-
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
| 1534 |
-
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
| 1535 |
-
...
|
| 1536 |
-
}
|
| 1537 |
-
|
| 1538 |
-
Args:
|
| 1539 |
-
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
| 1540 |
-
"""
|
| 1541 |
-
output = {}
|
| 1542 |
-
for attribute, inputs in tokenized.items():
|
| 1543 |
-
condition, mask = self.conditioners[attribute](inputs)
|
| 1544 |
-
output[attribute] = (condition, mask)
|
| 1545 |
-
return output
|
| 1546 |
-
|
| 1547 |
-
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
| 1548 |
-
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys
|
| 1549 |
-
are the attributes and the values are the aggregated input per attribute.
|
| 1550 |
-
For example:
|
| 1551 |
-
Input:
|
| 1552 |
-
[
|
| 1553 |
-
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
|
| 1554 |
-
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
|
| 1555 |
-
]
|
| 1556 |
-
Output:
|
| 1557 |
-
{
|
| 1558 |
-
"genre": ["Rock", "Hip-hop"],
|
| 1559 |
-
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
| 1560 |
-
}
|
| 1561 |
-
|
| 1562 |
-
Args:
|
| 1563 |
-
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
| 1564 |
-
Returns:
|
| 1565 |
-
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
| 1566 |
-
"""
|
| 1567 |
-
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
| 1568 |
-
texts = [x.text for x in samples]
|
| 1569 |
-
for text in texts:
|
| 1570 |
-
for condition in self.text_conditions:
|
| 1571 |
-
out[condition].append(text[condition])
|
| 1572 |
-
return out
|
| 1573 |
-
|
| 1574 |
-
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
|
| 1575 |
-
"""Generate a dict where the keys are attributes by which we fetch similar wavs,
|
| 1576 |
-
and the values are Tensors of wavs according to said attributes.
|
| 1577 |
-
|
| 1578 |
-
*Note*: by the time the samples reach this function, each sample should have some waveform
|
| 1579 |
-
inside the "wav" attribute. It should be either:
|
| 1580 |
-
1. A real waveform
|
| 1581 |
-
2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
|
| 1582 |
-
3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
|
| 1583 |
-
|
| 1584 |
-
Args:
|
| 1585 |
-
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
| 1586 |
-
Returns:
|
| 1587 |
-
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
|
| 1588 |
-
"""
|
| 1589 |
-
wavs = defaultdict(list)
|
| 1590 |
-
lengths = defaultdict(list)
|
| 1591 |
-
sample_rates = defaultdict(list)
|
| 1592 |
-
paths = defaultdict(list)
|
| 1593 |
-
seek_times = defaultdict(list)
|
| 1594 |
-
out: tp.Dict[str, WavCondition] = {}
|
| 1595 |
-
|
| 1596 |
-
for sample in samples:
|
| 1597 |
-
for attribute in self.wav_conditions:
|
| 1598 |
-
wav, length, sample_rate, path, seek_time = sample.wav[attribute]
|
| 1599 |
-
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
|
| 1600 |
-
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
|
| 1601 |
-
# mono-channel conditioning
|
| 1602 |
-
wav = wav.mean(1, keepdim=True) # [1, 1, T]
|
| 1603 |
-
wavs[attribute].append(wav.flatten()) # [T]
|
| 1604 |
-
lengths[attribute].append(length)
|
| 1605 |
-
sample_rates[attribute].extend(sample_rate)
|
| 1606 |
-
paths[attribute].extend(path)
|
| 1607 |
-
seek_times[attribute].extend(seek_time)
|
| 1608 |
-
|
| 1609 |
-
# stack all wavs to a single tensor
|
| 1610 |
-
for attribute in self.wav_conditions:
|
| 1611 |
-
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
| 1612 |
-
out[attribute] = WavCondition(
|
| 1613 |
-
stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
|
| 1614 |
-
paths[attribute], seek_times[attribute])
|
| 1615 |
-
|
| 1616 |
-
return out
|
| 1617 |
-
|
| 1618 |
-
def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
|
| 1619 |
-
"""Generate a dict where the keys are attributes by which we compute joint embeddings,
|
| 1620 |
-
and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
|
| 1621 |
-
|
| 1622 |
-
Args:
|
| 1623 |
-
samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
|
| 1624 |
-
Returns:
|
| 1625 |
-
A dictionary mapping an attribute name to joint embeddings.
|
| 1626 |
-
"""
|
| 1627 |
-
texts = defaultdict(list)
|
| 1628 |
-
wavs = defaultdict(list)
|
| 1629 |
-
lengths = defaultdict(list)
|
| 1630 |
-
sample_rates = defaultdict(list)
|
| 1631 |
-
paths = defaultdict(list)
|
| 1632 |
-
seek_times = defaultdict(list)
|
| 1633 |
-
channels: int = 0
|
| 1634 |
-
|
| 1635 |
-
out = {}
|
| 1636 |
-
for sample in samples:
|
| 1637 |
-
for attribute in self.joint_embed_conditions:
|
| 1638 |
-
wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
|
| 1639 |
-
assert wav.dim() == 3
|
| 1640 |
-
if channels == 0:
|
| 1641 |
-
channels = wav.size(1)
|
| 1642 |
-
else:
|
| 1643 |
-
assert channels == wav.size(1), "not all audio has same number of channels in batch"
|
| 1644 |
-
assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
|
| 1645 |
-
wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
|
| 1646 |
-
wavs[attribute].append(wav)
|
| 1647 |
-
texts[attribute].extend(text)
|
| 1648 |
-
lengths[attribute].append(length)
|
| 1649 |
-
sample_rates[attribute].extend(sample_rate)
|
| 1650 |
-
paths[attribute].extend(path)
|
| 1651 |
-
seek_times[attribute].extend(seek_time)
|
| 1652 |
-
|
| 1653 |
-
for attribute in self.joint_embed_conditions:
|
| 1654 |
-
stacked_texts = texts[attribute]
|
| 1655 |
-
stacked_paths = paths[attribute]
|
| 1656 |
-
stacked_seek_times = seek_times[attribute]
|
| 1657 |
-
stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
|
| 1658 |
-
stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
|
| 1659 |
-
stacked_sample_rates = sample_rates[attribute]
|
| 1660 |
-
stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
|
| 1661 |
-
assert stacked_lengths.size(0) == stacked_wavs.size(0)
|
| 1662 |
-
assert len(stacked_sample_rates) == stacked_wavs.size(0)
|
| 1663 |
-
assert len(stacked_texts) == stacked_wavs.size(0)
|
| 1664 |
-
out[attribute] = JointEmbedCondition(
|
| 1665 |
-
text=stacked_texts, wav=stacked_wavs,
|
| 1666 |
-
length=stacked_lengths, sample_rate=stacked_sample_rates,
|
| 1667 |
-
path=stacked_paths, seek_time=stacked_seek_times)
|
| 1668 |
-
|
| 1669 |
-
return out
|
| 1670 |
-
|
| 1671 |
-
|
| 1672 |
-
class ConditionFuser(StreamingModule):
|
| 1673 |
-
"""Condition fuser handles the logic to combine the different conditions
|
| 1674 |
-
to the actual model input.
|
| 1675 |
-
|
| 1676 |
-
Args:
|
| 1677 |
-
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
|
| 1678 |
-
each condition. For example:
|
| 1679 |
-
{
|
| 1680 |
-
"prepend": ["description"],
|
| 1681 |
-
"sum": ["genre", "bpm"],
|
| 1682 |
-
"cross": ["description"],
|
| 1683 |
-
}
|
| 1684 |
-
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
| 1685 |
-
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
| 1686 |
-
"""
|
| 1687 |
-
FUSING_METHODS = ["sum", "prepend", "cross", "ignore", "input_interpolate"]
|
| 1688 |
-
|
| 1689 |
-
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
| 1690 |
-
cross_attention_pos_emb_scale: float = 1.0):
|
| 1691 |
-
super().__init__()
|
| 1692 |
-
assert all(
|
| 1693 |
-
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
| 1694 |
-
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
|
| 1695 |
-
self.cross_attention_pos_emb = cross_attention_pos_emb
|
| 1696 |
-
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
|
| 1697 |
-
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
| 1698 |
-
self.cond2fuse: tp.Dict[str, str] = {}
|
| 1699 |
-
for fuse_method, conditions in fuse2cond.items():
|
| 1700 |
-
for condition in conditions:
|
| 1701 |
-
self.cond2fuse[condition] = fuse_method
|
| 1702 |
-
|
| 1703 |
-
def forward(
|
| 1704 |
-
self,
|
| 1705 |
-
input: torch.Tensor,
|
| 1706 |
-
conditions: tp.Dict[str, ConditionType]
|
| 1707 |
-
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 1708 |
-
"""Fuse the conditions to the provided model input.
|
| 1709 |
-
|
| 1710 |
-
Args:
|
| 1711 |
-
input (torch.Tensor): Transformer input.
|
| 1712 |
-
conditions (dict[str, ConditionType]): Dict of conditions.
|
| 1713 |
-
Returns:
|
| 1714 |
-
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
|
| 1715 |
-
after the conditions have been fused. The second output tensor is the tensor
|
| 1716 |
-
used for cross-attention or None if no cross attention inputs exist.
|
| 1717 |
-
"""
|
| 1718 |
-
B, T, _ = input.shape
|
| 1719 |
-
|
| 1720 |
-
if 'offsets' in self._streaming_state:
|
| 1721 |
-
first_step = False
|
| 1722 |
-
offsets = self._streaming_state['offsets']
|
| 1723 |
-
else:
|
| 1724 |
-
first_step = True
|
| 1725 |
-
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
|
| 1726 |
-
|
| 1727 |
-
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
| 1728 |
-
f"given conditions contain unknown attributes for fuser, " \
|
| 1729 |
-
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
|
| 1730 |
-
cross_attention_output = None
|
| 1731 |
-
for cond_type, (cond, cond_mask) in conditions.items():
|
| 1732 |
-
op = self.cond2fuse[cond_type]
|
| 1733 |
-
if op == 'sum':
|
| 1734 |
-
input += cond
|
| 1735 |
-
elif op == 'input_interpolate':
|
| 1736 |
-
cond = einops.rearrange(cond, "b t d -> b d t")
|
| 1737 |
-
cond = F.interpolate(cond, size=input.shape[1])
|
| 1738 |
-
input += einops.rearrange(cond, "b d t -> b t d")
|
| 1739 |
-
elif op == 'prepend':
|
| 1740 |
-
if first_step:
|
| 1741 |
-
input = torch.cat([cond, input], dim=1)
|
| 1742 |
-
elif op == 'cross':
|
| 1743 |
-
if cross_attention_output is not None:
|
| 1744 |
-
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
| 1745 |
-
else:
|
| 1746 |
-
cross_attention_output = cond
|
| 1747 |
-
elif op == 'ignore':
|
| 1748 |
-
continue
|
| 1749 |
-
else:
|
| 1750 |
-
raise ValueError(f"unknown op ({op})")
|
| 1751 |
-
|
| 1752 |
-
if self.cross_attention_pos_emb and cross_attention_output is not None:
|
| 1753 |
-
positions = torch.arange(
|
| 1754 |
-
cross_attention_output.shape[1],
|
| 1755 |
-
device=cross_attention_output.device
|
| 1756 |
-
).view(1, -1, 1)
|
| 1757 |
-
pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
|
| 1758 |
-
cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
|
| 1759 |
-
|
| 1760 |
-
if self._is_streaming:
|
| 1761 |
-
self._streaming_state['offsets'] = offsets + T
|
| 1762 |
-
|
| 1763 |
-
return input, cross_attention_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/conv.py
DELETED
|
@@ -1,245 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import math
|
| 8 |
-
import typing as tp
|
| 9 |
-
import warnings
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
from torch import nn
|
| 13 |
-
from torch.nn import functional as F
|
| 14 |
-
from torch.nn.utils.parametrizations import spectral_norm, weight_norm
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
| 18 |
-
'time_group_norm'])
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
|
| 22 |
-
assert norm in CONV_NORMALIZATIONS
|
| 23 |
-
if norm == 'weight_norm':
|
| 24 |
-
return weight_norm(module)
|
| 25 |
-
elif norm == 'spectral_norm':
|
| 26 |
-
return spectral_norm(module)
|
| 27 |
-
else:
|
| 28 |
-
# We already check was in CONV_NORMALIZATION, so any other choice
|
| 29 |
-
# doesn't need reparametrization.
|
| 30 |
-
return module
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
|
| 34 |
-
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
| 35 |
-
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
| 36 |
-
"""
|
| 37 |
-
assert norm in CONV_NORMALIZATIONS
|
| 38 |
-
if norm == 'time_group_norm':
|
| 39 |
-
if causal:
|
| 40 |
-
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
| 41 |
-
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 42 |
-
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
| 43 |
-
else:
|
| 44 |
-
return nn.Identity()
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
| 48 |
-
padding_total: int = 0) -> int:
|
| 49 |
-
"""See `pad_for_conv1d`.
|
| 50 |
-
"""
|
| 51 |
-
length = x.shape[-1]
|
| 52 |
-
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 53 |
-
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 54 |
-
return ideal_length - length
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
| 58 |
-
"""Pad for a convolution to make sure that the last window is full.
|
| 59 |
-
Extra padding is added at the end. This is required to ensure that we can rebuild
|
| 60 |
-
an output of the same length, as otherwise, even with padding, some time steps
|
| 61 |
-
might get removed.
|
| 62 |
-
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
| 63 |
-
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
| 64 |
-
1 2 3 # (output frames of a convolution, last 0 is never used)
|
| 65 |
-
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
| 66 |
-
1 2 3 4 # once you removed padding, we are missing one time step !
|
| 67 |
-
"""
|
| 68 |
-
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 69 |
-
return F.pad(x, (0, extra_padding))
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
|
| 73 |
-
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
| 74 |
-
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
| 75 |
-
"""
|
| 76 |
-
length = x.shape[-1]
|
| 77 |
-
padding_left, padding_right = paddings
|
| 78 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 79 |
-
if mode == 'reflect':
|
| 80 |
-
max_pad = max(padding_left, padding_right)
|
| 81 |
-
extra_pad = 0
|
| 82 |
-
if length <= max_pad:
|
| 83 |
-
extra_pad = max_pad - length + 1
|
| 84 |
-
x = F.pad(x, (0, extra_pad))
|
| 85 |
-
padded = F.pad(x, paddings, mode, value)
|
| 86 |
-
end = padded.shape[-1] - extra_pad
|
| 87 |
-
return padded[..., :end]
|
| 88 |
-
else:
|
| 89 |
-
return F.pad(x, paddings, mode, value)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
| 93 |
-
"""Remove padding from x, handling properly zero padding. Only for 1d!
|
| 94 |
-
"""
|
| 95 |
-
padding_left, padding_right = paddings
|
| 96 |
-
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 97 |
-
assert (padding_left + padding_right) <= x.shape[-1]
|
| 98 |
-
end = x.shape[-1] - padding_right
|
| 99 |
-
return x[..., padding_left: end]
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
class NormConv1d(nn.Module):
|
| 103 |
-
"""Wrapper around Conv1d and normalization applied to this conv
|
| 104 |
-
to provide a uniform interface across normalization approaches.
|
| 105 |
-
"""
|
| 106 |
-
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 107 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 108 |
-
super().__init__()
|
| 109 |
-
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
| 110 |
-
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
| 111 |
-
self.norm_type = norm
|
| 112 |
-
|
| 113 |
-
def forward(self, x):
|
| 114 |
-
x = self.conv(x)
|
| 115 |
-
x = self.norm(x)
|
| 116 |
-
return x
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
class NormConv2d(nn.Module):
|
| 120 |
-
"""Wrapper around Conv2d and normalization applied to this conv
|
| 121 |
-
to provide a uniform interface across normalization approaches.
|
| 122 |
-
"""
|
| 123 |
-
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 124 |
-
super().__init__()
|
| 125 |
-
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
| 126 |
-
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
| 127 |
-
self.norm_type = norm
|
| 128 |
-
|
| 129 |
-
def forward(self, x):
|
| 130 |
-
x = self.conv(x)
|
| 131 |
-
x = self.norm(x)
|
| 132 |
-
return x
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class NormConvTranspose1d(nn.Module):
|
| 136 |
-
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
| 137 |
-
to provide a uniform interface across normalization approaches.
|
| 138 |
-
"""
|
| 139 |
-
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 140 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 141 |
-
super().__init__()
|
| 142 |
-
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
| 143 |
-
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
| 144 |
-
self.norm_type = norm
|
| 145 |
-
|
| 146 |
-
def forward(self, x):
|
| 147 |
-
x = self.convtr(x)
|
| 148 |
-
x = self.norm(x)
|
| 149 |
-
return x
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
class NormConvTranspose2d(nn.Module):
|
| 153 |
-
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
| 154 |
-
to provide a uniform interface across normalization approaches.
|
| 155 |
-
"""
|
| 156 |
-
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 157 |
-
super().__init__()
|
| 158 |
-
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
| 159 |
-
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
| 160 |
-
|
| 161 |
-
def forward(self, x):
|
| 162 |
-
x = self.convtr(x)
|
| 163 |
-
x = self.norm(x)
|
| 164 |
-
return x
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
class StreamableConv1d(nn.Module):
|
| 168 |
-
"""Conv1d with some builtin handling of asymmetric or causal padding
|
| 169 |
-
and normalization.
|
| 170 |
-
"""
|
| 171 |
-
def __init__(self, in_channels: int, out_channels: int,
|
| 172 |
-
kernel_size: int, stride: int = 1, dilation: int = 1,
|
| 173 |
-
groups: int = 1, bias: bool = True, causal: bool = False,
|
| 174 |
-
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 175 |
-
pad_mode: str = 'reflect'):
|
| 176 |
-
super().__init__()
|
| 177 |
-
# warn user on unusual setup between dilation and stride
|
| 178 |
-
if stride > 1 and dilation > 1:
|
| 179 |
-
warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
|
| 180 |
-
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
|
| 181 |
-
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
| 182 |
-
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
| 183 |
-
norm=norm, norm_kwargs=norm_kwargs)
|
| 184 |
-
self.causal = causal
|
| 185 |
-
self.pad_mode = pad_mode
|
| 186 |
-
|
| 187 |
-
def forward(self, x):
|
| 188 |
-
B, C, T = x.shape
|
| 189 |
-
kernel_size = self.conv.conv.kernel_size[0]
|
| 190 |
-
stride = self.conv.conv.stride[0]
|
| 191 |
-
dilation = self.conv.conv.dilation[0]
|
| 192 |
-
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
| 193 |
-
padding_total = kernel_size - stride
|
| 194 |
-
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 195 |
-
if self.causal:
|
| 196 |
-
# Left padding for causal
|
| 197 |
-
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
| 198 |
-
else:
|
| 199 |
-
# Asymmetric padding required for odd strides
|
| 200 |
-
padding_right = padding_total // 2
|
| 201 |
-
padding_left = padding_total - padding_right
|
| 202 |
-
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
| 203 |
-
return self.conv(x)
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
class StreamableConvTranspose1d(nn.Module):
|
| 207 |
-
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
| 208 |
-
and normalization.
|
| 209 |
-
"""
|
| 210 |
-
def __init__(self, in_channels: int, out_channels: int,
|
| 211 |
-
kernel_size: int, stride: int = 1, causal: bool = False,
|
| 212 |
-
norm: str = 'none', trim_right_ratio: float = 1.,
|
| 213 |
-
norm_kwargs: tp.Dict[str, tp.Any] = {}):
|
| 214 |
-
super().__init__()
|
| 215 |
-
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
| 216 |
-
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
|
| 217 |
-
self.causal = causal
|
| 218 |
-
self.trim_right_ratio = trim_right_ratio
|
| 219 |
-
assert self.causal or self.trim_right_ratio == 1., \
|
| 220 |
-
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
| 221 |
-
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
| 222 |
-
|
| 223 |
-
def forward(self, x):
|
| 224 |
-
kernel_size = self.convtr.convtr.kernel_size[0]
|
| 225 |
-
stride = self.convtr.convtr.stride[0]
|
| 226 |
-
padding_total = kernel_size - stride
|
| 227 |
-
|
| 228 |
-
y = self.convtr(x)
|
| 229 |
-
|
| 230 |
-
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
| 231 |
-
# removed at the very end, when keeping only the right length for the output,
|
| 232 |
-
# as removing it here would require also passing the length at the matching layer
|
| 233 |
-
# in the encoder.
|
| 234 |
-
if self.causal:
|
| 235 |
-
# Trim the padding on the right according to the specified ratio
|
| 236 |
-
# if trim_right_ratio = 1.0, trim everything from right
|
| 237 |
-
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
| 238 |
-
padding_left = padding_total - padding_right
|
| 239 |
-
y = unpad1d(y, (padding_left, padding_right))
|
| 240 |
-
else:
|
| 241 |
-
# Asymmetric padding required for odd strides
|
| 242 |
-
padding_right = padding_total // 2
|
| 243 |
-
padding_left = padding_total - padding_right
|
| 244 |
-
y = unpad1d(y, (padding_left, padding_right))
|
| 245 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/diffusion_schedule.py
DELETED
|
@@ -1,272 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from collections import namedtuple
|
| 12 |
-
import random
|
| 13 |
-
import typing as tp
|
| 14 |
-
import julius
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
TrainingItem = namedtuple("TrainingItem", "noisy noise step")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def betas_from_alpha_bar(alpha_bar):
|
| 21 |
-
alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
|
| 22 |
-
return 1 - alphas
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class SampleProcessor(torch.nn.Module):
|
| 26 |
-
def project_sample(self, x: torch.Tensor):
|
| 27 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
| 28 |
-
return x
|
| 29 |
-
|
| 30 |
-
def return_sample(self, z: torch.Tensor):
|
| 31 |
-
"""Project back from diffusion space to the actual sample space."""
|
| 32 |
-
return z
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class MultiBandProcessor(SampleProcessor):
|
| 36 |
-
"""
|
| 37 |
-
MultiBand sample processor. The input audio is splitted across
|
| 38 |
-
frequency bands evenly distributed in mel-scale.
|
| 39 |
-
|
| 40 |
-
Each band will be rescaled to match the power distribution
|
| 41 |
-
of Gaussian noise in that band, using online metrics
|
| 42 |
-
computed on the first few samples.
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
n_bands (int): Number of mel-bands to split the signal over.
|
| 46 |
-
sample_rate (int): Sample rate of the audio.
|
| 47 |
-
num_samples (int): Number of samples to use to fit the rescaling
|
| 48 |
-
for each band. The processor won't be stable
|
| 49 |
-
until it has seen that many samples.
|
| 50 |
-
power_std (float or list/tensor): The rescaling factor computed to match the
|
| 51 |
-
power of Gaussian noise in each band is taken to
|
| 52 |
-
that power, i.e. `1.` means full correction of the energy
|
| 53 |
-
in each band, and values less than `1` means only partial
|
| 54 |
-
correction. Can be used to balance the relative importance
|
| 55 |
-
of low vs. high freq in typical audio signals.
|
| 56 |
-
"""
|
| 57 |
-
def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
|
| 58 |
-
num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
|
| 59 |
-
super().__init__()
|
| 60 |
-
self.n_bands = n_bands
|
| 61 |
-
self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
|
| 62 |
-
self.num_samples = num_samples
|
| 63 |
-
self.power_std = power_std
|
| 64 |
-
if isinstance(power_std, list):
|
| 65 |
-
assert len(power_std) == n_bands
|
| 66 |
-
power_std = torch.tensor(power_std)
|
| 67 |
-
self.register_buffer('counts', torch.zeros(1))
|
| 68 |
-
self.register_buffer('sum_x', torch.zeros(n_bands))
|
| 69 |
-
self.register_buffer('sum_x2', torch.zeros(n_bands))
|
| 70 |
-
self.register_buffer('sum_target_x2', torch.zeros(n_bands))
|
| 71 |
-
self.counts: torch.Tensor
|
| 72 |
-
self.sum_x: torch.Tensor
|
| 73 |
-
self.sum_x2: torch.Tensor
|
| 74 |
-
self.sum_target_x2: torch.Tensor
|
| 75 |
-
|
| 76 |
-
@property
|
| 77 |
-
def mean(self):
|
| 78 |
-
mean = self.sum_x / self.counts
|
| 79 |
-
return mean
|
| 80 |
-
|
| 81 |
-
@property
|
| 82 |
-
def std(self):
|
| 83 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
| 84 |
-
return std
|
| 85 |
-
|
| 86 |
-
@property
|
| 87 |
-
def target_std(self):
|
| 88 |
-
target_std = self.sum_target_x2 / self.counts
|
| 89 |
-
return target_std
|
| 90 |
-
|
| 91 |
-
def project_sample(self, x: torch.Tensor):
|
| 92 |
-
assert x.dim() == 3
|
| 93 |
-
bands = self.split_bands(x)
|
| 94 |
-
if self.counts.item() < self.num_samples:
|
| 95 |
-
ref_bands = self.split_bands(torch.randn_like(x))
|
| 96 |
-
self.counts += len(x)
|
| 97 |
-
self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
|
| 98 |
-
self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
| 99 |
-
self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
| 100 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
| 101 |
-
bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
|
| 102 |
-
return bands.sum(dim=0)
|
| 103 |
-
|
| 104 |
-
def return_sample(self, x: torch.Tensor):
|
| 105 |
-
assert x.dim() == 3
|
| 106 |
-
bands = self.split_bands(x)
|
| 107 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
| 108 |
-
bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
|
| 109 |
-
return bands.sum(dim=0)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class NoiseSchedule:
|
| 113 |
-
"""Noise schedule for diffusion.
|
| 114 |
-
|
| 115 |
-
Args:
|
| 116 |
-
beta_t0 (float): Variance of the first diffusion step.
|
| 117 |
-
beta_t1 (float): Variance of the last diffusion step.
|
| 118 |
-
beta_exp (float): Power schedule exponent
|
| 119 |
-
num_steps (int): Number of diffusion step.
|
| 120 |
-
variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
|
| 121 |
-
clip (float): clipping value for the denoising steps
|
| 122 |
-
rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
|
| 123 |
-
repartition (str): shape of the schedule only power schedule is supported
|
| 124 |
-
sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
|
| 125 |
-
noise_scale (float): Scaling factor for the noise
|
| 126 |
-
"""
|
| 127 |
-
def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
|
| 128 |
-
clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
|
| 129 |
-
repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
|
| 130 |
-
sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
|
| 131 |
-
|
| 132 |
-
self.beta_t0 = beta_t0
|
| 133 |
-
self.beta_t1 = beta_t1
|
| 134 |
-
self.variance = variance
|
| 135 |
-
self.num_steps = num_steps
|
| 136 |
-
self.clip = clip
|
| 137 |
-
self.sample_processor = sample_processor
|
| 138 |
-
self.rescale = rescale
|
| 139 |
-
self.n_bands = n_bands
|
| 140 |
-
self.noise_scale = noise_scale
|
| 141 |
-
assert n_bands is None
|
| 142 |
-
if repartition == "power":
|
| 143 |
-
self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
|
| 144 |
-
device=device, dtype=torch.float) ** beta_exp
|
| 145 |
-
else:
|
| 146 |
-
raise RuntimeError('Not implemented')
|
| 147 |
-
self.rng = random.Random(1234)
|
| 148 |
-
|
| 149 |
-
def get_beta(self, step: tp.Union[int, torch.Tensor]):
|
| 150 |
-
if self.n_bands is None:
|
| 151 |
-
return self.betas[step]
|
| 152 |
-
else:
|
| 153 |
-
return self.betas[:, step] # [n_bands, len(step)]
|
| 154 |
-
|
| 155 |
-
def get_initial_noise(self, x: torch.Tensor):
|
| 156 |
-
if self.n_bands is None:
|
| 157 |
-
return torch.randn_like(x)
|
| 158 |
-
return torch.randn((x.size(0), self.n_bands, x.size(2)))
|
| 159 |
-
|
| 160 |
-
def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
|
| 161 |
-
"""Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
|
| 162 |
-
if step is None:
|
| 163 |
-
return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
|
| 164 |
-
if type(step) is int:
|
| 165 |
-
return (1 - self.betas[:step + 1]).prod()
|
| 166 |
-
else:
|
| 167 |
-
return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
|
| 168 |
-
|
| 169 |
-
def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
|
| 170 |
-
"""Create a noisy data item for diffusion model training:
|
| 171 |
-
|
| 172 |
-
Args:
|
| 173 |
-
x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
|
| 174 |
-
tensor_step (bool): If tensor_step = false, only one step t is sample,
|
| 175 |
-
the whole batch is diffused to the same step and t is int.
|
| 176 |
-
If tensor_step = true, t is a tensor of size (x.size(0),)
|
| 177 |
-
every element of the batch is diffused to a independently sampled.
|
| 178 |
-
"""
|
| 179 |
-
step: tp.Union[int, torch.Tensor]
|
| 180 |
-
if tensor_step:
|
| 181 |
-
bs = x.size(0)
|
| 182 |
-
step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
|
| 183 |
-
else:
|
| 184 |
-
step = self.rng.randrange(self.num_steps)
|
| 185 |
-
alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
|
| 186 |
-
|
| 187 |
-
x = self.sample_processor.project_sample(x)
|
| 188 |
-
noise = torch.randn_like(x)
|
| 189 |
-
noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
|
| 190 |
-
return TrainingItem(noisy, noise, step)
|
| 191 |
-
|
| 192 |
-
def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
|
| 193 |
-
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
| 194 |
-
"""Full ddpm reverse process.
|
| 195 |
-
|
| 196 |
-
Args:
|
| 197 |
-
model (nn.Module): Diffusion model.
|
| 198 |
-
initial (tensor): Initial Noise.
|
| 199 |
-
condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
|
| 200 |
-
return_list (bool): Whether to return the whole process or only the sampled point.
|
| 201 |
-
"""
|
| 202 |
-
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
| 203 |
-
current = initial
|
| 204 |
-
iterates = [initial]
|
| 205 |
-
for step in range(self.num_steps)[::-1]:
|
| 206 |
-
with torch.no_grad():
|
| 207 |
-
estimate = model(current, step, condition=condition).sample
|
| 208 |
-
alpha = 1 - self.betas[step]
|
| 209 |
-
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
| 210 |
-
previous_alpha_bar = self.get_alpha_bar(step=step - 1)
|
| 211 |
-
if step == 0:
|
| 212 |
-
sigma2 = 0
|
| 213 |
-
elif self.variance == 'beta':
|
| 214 |
-
sigma2 = 1 - alpha
|
| 215 |
-
elif self.variance == 'beta_tilde':
|
| 216 |
-
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
| 217 |
-
elif self.variance == 'none':
|
| 218 |
-
sigma2 = 0
|
| 219 |
-
else:
|
| 220 |
-
raise ValueError(f'Invalid variance type {self.variance}')
|
| 221 |
-
|
| 222 |
-
if sigma2 > 0:
|
| 223 |
-
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
| 224 |
-
if self.clip:
|
| 225 |
-
previous = previous.clamp(-self.clip, self.clip)
|
| 226 |
-
current = previous
|
| 227 |
-
alpha_bar = previous_alpha_bar
|
| 228 |
-
if step == 0:
|
| 229 |
-
previous *= self.rescale
|
| 230 |
-
if return_list:
|
| 231 |
-
iterates.append(previous.cpu())
|
| 232 |
-
|
| 233 |
-
if return_list:
|
| 234 |
-
return iterates
|
| 235 |
-
else:
|
| 236 |
-
return self.sample_processor.return_sample(previous)
|
| 237 |
-
|
| 238 |
-
def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
|
| 239 |
-
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
| 240 |
-
"""Reverse process that only goes through Markov chain states in step_list."""
|
| 241 |
-
if step_list is None:
|
| 242 |
-
step_list = list(range(1000))[::-50] + [0]
|
| 243 |
-
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
| 244 |
-
alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
|
| 245 |
-
betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
|
| 246 |
-
current = initial * self.noise_scale
|
| 247 |
-
iterates = [current]
|
| 248 |
-
for idx, step in enumerate(step_list[:-1]):
|
| 249 |
-
with torch.no_grad():
|
| 250 |
-
estimate = model(current, step, condition=condition).sample * self.noise_scale
|
| 251 |
-
alpha = 1 - betas_subsampled[-1 - idx]
|
| 252 |
-
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
| 253 |
-
previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
|
| 254 |
-
if step == step_list[-2]:
|
| 255 |
-
sigma2 = 0
|
| 256 |
-
previous_alpha_bar = torch.tensor(1.0)
|
| 257 |
-
else:
|
| 258 |
-
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
| 259 |
-
if sigma2 > 0:
|
| 260 |
-
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
| 261 |
-
if self.clip:
|
| 262 |
-
previous = previous.clamp(-self.clip, self.clip)
|
| 263 |
-
current = previous
|
| 264 |
-
alpha_bar = previous_alpha_bar
|
| 265 |
-
if step == 0:
|
| 266 |
-
previous *= self.rescale
|
| 267 |
-
if return_list:
|
| 268 |
-
iterates.append(previous.cpu())
|
| 269 |
-
if return_list:
|
| 270 |
-
return iterates
|
| 271 |
-
else:
|
| 272 |
-
return self.sample_processor.return_sample(previous)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/jasco_conditioners.py
DELETED
|
@@ -1,300 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import typing as tp
|
| 3 |
-
from itertools import chain
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from torch import nn
|
| 6 |
-
from .conditioners import (ConditioningAttributes, BaseConditioner, ConditionType,
|
| 7 |
-
ConditioningProvider, JascoCondConst,
|
| 8 |
-
WaveformConditioner, WavCondition, SymbolicCondition)
|
| 9 |
-
from ..data.audio import audio_read
|
| 10 |
-
from ..data.audio_utils import convert_audio
|
| 11 |
-
from ..utils.autocast import TorchAutocast
|
| 12 |
-
from ..utils.cache import EmbeddingCache
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class MelodyConditioner(BaseConditioner):
|
| 16 |
-
"""
|
| 17 |
-
A conditioner that handles melody conditioning from pre-computed salience matrix.
|
| 18 |
-
Attributes:
|
| 19 |
-
card (int): The cardinality of the melody matrix.
|
| 20 |
-
out_dim (int): The dimensionality of the output projection.
|
| 21 |
-
device (Union[torch.device, str]): The device on which the embeddings are stored.
|
| 22 |
-
"""
|
| 23 |
-
def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
| 24 |
-
super().__init__(dim=card, output_dim=out_dim)
|
| 25 |
-
self.device = device
|
| 26 |
-
|
| 27 |
-
def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
|
| 28 |
-
return SymbolicCondition(melody=x.melody.to(self.device)) # type: ignore
|
| 29 |
-
|
| 30 |
-
def forward(self, x: SymbolicCondition) -> ConditionType:
|
| 31 |
-
embeds = self.output_proj(x.melody.permute(0, 2, 1)) # type: ignore
|
| 32 |
-
mask = torch.ones_like(embeds[..., 0])
|
| 33 |
-
return embeds, mask
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class ChordsEmbConditioner(BaseConditioner):
|
| 37 |
-
"""
|
| 38 |
-
A conditioner that embeds chord symbols into a continuous vector space.
|
| 39 |
-
Attributes:
|
| 40 |
-
card (int): The cardinality of the chord vocabulary.
|
| 41 |
-
out_dim (int): The dimensionality of the output embeddings.
|
| 42 |
-
device (Union[torch.device, str]): The device on which the embeddings are stored.
|
| 43 |
-
"""
|
| 44 |
-
def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
| 45 |
-
vocab_size = card + 1 # card + 1 - for null chord used during dropout
|
| 46 |
-
super().__init__(dim=vocab_size, output_dim=-1) # out_dim=-1 to avoid another projection
|
| 47 |
-
self.emb = nn.Embedding(vocab_size, out_dim, device=device)
|
| 48 |
-
self.device = device
|
| 49 |
-
|
| 50 |
-
def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
|
| 51 |
-
return SymbolicCondition(frame_chords=x.frame_chords.to(self.device)) # type: ignore
|
| 52 |
-
|
| 53 |
-
def forward(self, x: SymbolicCondition) -> ConditionType:
|
| 54 |
-
embeds = self.emb(x.frame_chords)
|
| 55 |
-
mask = torch.ones_like(embeds[..., 0])
|
| 56 |
-
return embeds, mask
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class DrumsConditioner(WaveformConditioner):
|
| 60 |
-
def __init__(self, out_dim: int, sample_rate: int, blurring_factor: int = 3,
|
| 61 |
-
cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
| 62 |
-
compression_model_latent_dim: int = 128,
|
| 63 |
-
compression_model_framerate: float = 50,
|
| 64 |
-
segment_duration: float = 10.0,
|
| 65 |
-
device: tp.Union[torch.device, str] = 'cpu',
|
| 66 |
-
**kwargs):
|
| 67 |
-
"""Drum condition conditioner
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
out_dim (int): _description_
|
| 71 |
-
sample_rate (int): _description_
|
| 72 |
-
blurring_factor (int, optional): _description_. Defaults to 3.
|
| 73 |
-
cache_path (tp.Optional[tp.Union[str, Path]], optional): path to precomputed cache. Defaults to None.
|
| 74 |
-
compression_model_latent_dim (int, optional): latent dimensino. Defaults to 128.
|
| 75 |
-
compression_model_framerate (float, optional): frame rate of the representation model. Defaults to 50.
|
| 76 |
-
segment_duration (float, optional): duration in sec for each audio segment. Defaults to 10.0.
|
| 77 |
-
device (tp.Union[torch.device, str], optional): device. Defaults to 'cpu'.
|
| 78 |
-
"""
|
| 79 |
-
from demucs import pretrained
|
| 80 |
-
self.sample_rate = sample_rate
|
| 81 |
-
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
| 82 |
-
stem_sources: list = self.demucs.sources # type: ignore
|
| 83 |
-
self.stem_idx = stem_sources.index('drums')
|
| 84 |
-
self.compression_model = None
|
| 85 |
-
self.latent_dim = compression_model_latent_dim
|
| 86 |
-
super().__init__(dim=self.latent_dim, output_dim=out_dim, device=device)
|
| 87 |
-
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
| 88 |
-
self._use_masking = False
|
| 89 |
-
self.blurring_factor = blurring_factor
|
| 90 |
-
self.seq_len = int(segment_duration * compression_model_framerate)
|
| 91 |
-
self.cache = None # If you wish to train with EmbeddingCache, call self.create_embedding_cache(cache_path)
|
| 92 |
-
|
| 93 |
-
def create_embedding_cache(self, cache_path):
|
| 94 |
-
if cache_path is not None:
|
| 95 |
-
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
| 96 |
-
compute_embed_fn=self._calc_coarse_drum_codes_for_cache,
|
| 97 |
-
extract_embed_fn=self._load_drum_codes_chunk)
|
| 98 |
-
|
| 99 |
-
@torch.no_grad()
|
| 100 |
-
def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 101 |
-
"""Get parts of the wav that holds the drums, extracting the main stems from the wav."""
|
| 102 |
-
from demucs.apply import apply_model
|
| 103 |
-
from demucs.audio import convert_audio
|
| 104 |
-
with self.autocast:
|
| 105 |
-
wav = convert_audio(
|
| 106 |
-
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
| 107 |
-
stems = apply_model(self.demucs, wav, device=self.device)
|
| 108 |
-
drum_stem = stems[:, self.stem_idx] # extract relevant stems for drums conditioning
|
| 109 |
-
return convert_audio(drum_stem, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
| 110 |
-
|
| 111 |
-
def _temporal_blur(self, z: torch.Tensor):
|
| 112 |
-
# z: (B, T, C)
|
| 113 |
-
B, T, C = z.shape
|
| 114 |
-
if T % self.blurring_factor != 0:
|
| 115 |
-
# pad with reflect for T % self.temporal_blurring on the right in dim=1
|
| 116 |
-
pad_val = self.blurring_factor - T % self.blurring_factor
|
| 117 |
-
z = torch.nn.functional.pad(z, (0, 0, 0, pad_val), mode='reflect')
|
| 118 |
-
z = z.reshape(B, -1, self.blurring_factor, C).sum(dim=2) / self.blurring_factor
|
| 119 |
-
z = z.unsqueeze(2).repeat(1, 1, self.blurring_factor, 1).reshape(B, -1, C)
|
| 120 |
-
z = z[:, :T]
|
| 121 |
-
assert z.shape == (B, T, C)
|
| 122 |
-
return z
|
| 123 |
-
|
| 124 |
-
@torch.no_grad()
|
| 125 |
-
def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 126 |
-
assert self.compression_model is not None
|
| 127 |
-
|
| 128 |
-
# stem separation of drums
|
| 129 |
-
drums = self._get_drums_stem(wav, sample_rate)
|
| 130 |
-
|
| 131 |
-
# continuous encoding with compression model
|
| 132 |
-
latents = self.compression_model.model.encoder(drums)
|
| 133 |
-
|
| 134 |
-
# quantization to coarsest codebook
|
| 135 |
-
coarsest_quantizer = self.compression_model.model.quantizer.layers[0]
|
| 136 |
-
drums = coarsest_quantizer.encode(latents).to(torch.int16)
|
| 137 |
-
return drums
|
| 138 |
-
|
| 139 |
-
@torch.no_grad()
|
| 140 |
-
def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path],
|
| 141 |
-
x: WavCondition, idx: int,
|
| 142 |
-
max_duration_to_process: float = 600) -> torch.Tensor:
|
| 143 |
-
"""Extract blurred drum latents from the whole audio waveform at the given path."""
|
| 144 |
-
wav, sr = audio_read(path)
|
| 145 |
-
wav = wav[None].to(self.device)
|
| 146 |
-
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
| 147 |
-
|
| 148 |
-
max_frames_to_process = int(max_duration_to_process * self.sample_rate)
|
| 149 |
-
if wav.shape[-1] > max_frames_to_process:
|
| 150 |
-
# process very long tracks in chunks
|
| 151 |
-
start = 0
|
| 152 |
-
codes = []
|
| 153 |
-
while start < wav.shape[-1] - 1:
|
| 154 |
-
wav_chunk = wav[..., start: start + max_frames_to_process]
|
| 155 |
-
codes.append(self._extract_coarse_drum_codes(wav_chunk, self.sample_rate)[0])
|
| 156 |
-
start += max_frames_to_process
|
| 157 |
-
return torch.cat(codes)
|
| 158 |
-
|
| 159 |
-
return self._extract_coarse_drum_codes(wav, self.sample_rate)[0]
|
| 160 |
-
|
| 161 |
-
def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
| 162 |
-
"""Extract a chunk of coarse drum codes from the full coarse drum codes derived from the full waveform."""
|
| 163 |
-
wav_length = x.wav.shape[-1]
|
| 164 |
-
seek_time = x.seek_time[idx]
|
| 165 |
-
assert seek_time is not None, (
|
| 166 |
-
"WavCondition seek_time is required "
|
| 167 |
-
"when extracting chunks from pre-computed drum codes.")
|
| 168 |
-
assert self.compression_model is not None
|
| 169 |
-
frame_rate = self.compression_model.frame_rate
|
| 170 |
-
target_length = int(frame_rate * wav_length / self.sample_rate)
|
| 171 |
-
target_length = max(target_length, self.seq_len)
|
| 172 |
-
index = int(frame_rate * seek_time)
|
| 173 |
-
out = full_coarse_drum_codes[index: index + target_length]
|
| 174 |
-
# pad
|
| 175 |
-
out = torch.cat((out, torch.zeros(target_length - out.shape[0], dtype=out.dtype, device=out.device)))
|
| 176 |
-
return out.to(self.device)
|
| 177 |
-
|
| 178 |
-
@torch.no_grad()
|
| 179 |
-
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 180 |
-
bs = x.wav.shape[0]
|
| 181 |
-
if x.wav.shape[-1] <= 1:
|
| 182 |
-
# null condition
|
| 183 |
-
return torch.zeros((bs, self.seq_len, self.latent_dim), device=x.wav.device, dtype=x.wav.dtype)
|
| 184 |
-
|
| 185 |
-
# extract coarse drum codes
|
| 186 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 187 |
-
no_nullified_cond = x.wav.shape[-1] > 1
|
| 188 |
-
if self.cache is not None and no_undefined_paths and no_nullified_cond:
|
| 189 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 190 |
-
codes = self.cache.get_embed_from_cache(paths, x)
|
| 191 |
-
else:
|
| 192 |
-
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
| 193 |
-
codes = self._extract_coarse_drum_codes(x.wav, x.sample_rate[0])
|
| 194 |
-
|
| 195 |
-
assert self.compression_model is not None
|
| 196 |
-
# decode back to the continuous representation of compression model
|
| 197 |
-
codes = codes.unsqueeze(1).permute(1, 0, 2) # (B, T) -> (1, B, T)
|
| 198 |
-
codes = codes.to(torch.int64)
|
| 199 |
-
latents = self.compression_model.model.quantizer.decode(codes)
|
| 200 |
-
|
| 201 |
-
latents = latents.permute(0, 2, 1) # [B, C, T] -> [B, T, C]
|
| 202 |
-
|
| 203 |
-
# temporal blurring
|
| 204 |
-
return self._temporal_blur(latents)
|
| 205 |
-
|
| 206 |
-
def tokenize(self, x: WavCondition) -> WavCondition:
|
| 207 |
-
"""Apply WavConditioner tokenization and populate cache if needed."""
|
| 208 |
-
x = super().tokenize(x)
|
| 209 |
-
no_undefined_paths = all(p is not None for p in x.path)
|
| 210 |
-
if self.cache is not None and no_undefined_paths:
|
| 211 |
-
paths = [Path(p) for p in x.path if p is not None]
|
| 212 |
-
self.cache.populate_embed_cache(paths, x)
|
| 213 |
-
return x
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
class JascoConditioningProvider(ConditioningProvider):
|
| 217 |
-
"""
|
| 218 |
-
A cond-provider that manages and tokenizes various types of conditioning attributes for Jasco models.
|
| 219 |
-
Attributes:
|
| 220 |
-
chords_card (int): The cardinality of the chord vocabulary.
|
| 221 |
-
sequence_length (int): The length of the sequence for padding purposes.
|
| 222 |
-
melody_dim (int): The dimensionality of the melody matrix.
|
| 223 |
-
"""
|
| 224 |
-
def __init__(self, *args,
|
| 225 |
-
chords_card: int = 194,
|
| 226 |
-
sequence_length: int = 500,
|
| 227 |
-
melody_dim: int = 53, **kwargs):
|
| 228 |
-
self.null_chord = chords_card
|
| 229 |
-
self.sequence_len = sequence_length
|
| 230 |
-
self.melody_dim = melody_dim
|
| 231 |
-
super().__init__(*args, **kwargs)
|
| 232 |
-
|
| 233 |
-
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
| 234 |
-
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
|
| 235 |
-
This should be called before starting any real GPU work to avoid synchronization points.
|
| 236 |
-
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
| 237 |
-
|
| 238 |
-
Args:
|
| 239 |
-
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
| 240 |
-
text and wav conditions.
|
| 241 |
-
"""
|
| 242 |
-
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
| 243 |
-
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
| 244 |
-
f" but types were {set([type(x) for x in inputs])}"
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
output = {}
|
| 248 |
-
text = self._collate_text(inputs)
|
| 249 |
-
wavs = self._collate_wavs(inputs)
|
| 250 |
-
|
| 251 |
-
symbolic = self._collate_symbolic(inputs, self.conditioners.keys())
|
| 252 |
-
|
| 253 |
-
assert set(text.keys() | wavs.keys() | symbolic.keys()).issubset(set(self.conditioners.keys())), (
|
| 254 |
-
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
| 255 |
-
f"got {text.keys(), wavs.keys(), symbolic.keys()}"
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
for attribute, batch in chain(text.items(), wavs.items(), symbolic.items()):
|
| 259 |
-
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
| 260 |
-
return output
|
| 261 |
-
|
| 262 |
-
def _collate_symbolic(self, samples: tp.List[ConditioningAttributes],
|
| 263 |
-
conditioner_keys: tp.Set) -> tp.Dict[str, SymbolicCondition]:
|
| 264 |
-
output = {}
|
| 265 |
-
|
| 266 |
-
# collate if symbolic cond exists
|
| 267 |
-
if any(x in conditioner_keys for x in JascoCondConst.SYM.value):
|
| 268 |
-
|
| 269 |
-
for s in samples:
|
| 270 |
-
# hydrate with null chord if chords not exist - for inference support
|
| 271 |
-
if (s.symbolic == {} or
|
| 272 |
-
s.symbolic[JascoCondConst.CRD.value].frame_chords is None or
|
| 273 |
-
s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1): # type: ignore
|
| 274 |
-
# no chords conditioning - fill with null chord token
|
| 275 |
-
s.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(
|
| 276 |
-
frame_chords=torch.ones(self.sequence_len, dtype=torch.int32) * self.null_chord)
|
| 277 |
-
|
| 278 |
-
if (s.symbolic == {} or
|
| 279 |
-
s.symbolic[JascoCondConst.MLD.value].melody is None or
|
| 280 |
-
s.symbolic[JascoCondConst.MLD.value].melody.shape[-1] <= 1): # type: ignore
|
| 281 |
-
# no chords conditioning - fill with null chord token
|
| 282 |
-
s.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(
|
| 283 |
-
melody=torch.zeros((self.melody_dim, self.sequence_len)))
|
| 284 |
-
|
| 285 |
-
if JascoCondConst.CRD.value in conditioner_keys:
|
| 286 |
-
# pad to max
|
| 287 |
-
max_seq_len = max(
|
| 288 |
-
[s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] for s in samples]) # type: ignore
|
| 289 |
-
padded_chords = [
|
| 290 |
-
torch.cat((x.symbolic[JascoCondConst.CRD.value].frame_chords, # type: ignore
|
| 291 |
-
torch.ones(max_seq_len -
|
| 292 |
-
x.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1], # type: ignore
|
| 293 |
-
dtype=torch.int32) * self.null_chord))
|
| 294 |
-
for x in samples
|
| 295 |
-
]
|
| 296 |
-
output[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.stack(padded_chords))
|
| 297 |
-
if JascoCondConst.MLD.value in conditioner_keys:
|
| 298 |
-
melodies = torch.stack([x.symbolic[JascoCondConst.MLD.value].melody for x in samples]) # type: ignore
|
| 299 |
-
output[JascoCondConst.MLD.value] = SymbolicCondition(melody=melodies)
|
| 300 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/lstm.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from torch import nn
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class StreamableLSTM(nn.Module):
|
| 11 |
-
"""LSTM without worrying about the hidden state, nor the layout of the data.
|
| 12 |
-
Expects input as convolutional layout.
|
| 13 |
-
"""
|
| 14 |
-
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
| 15 |
-
super().__init__()
|
| 16 |
-
self.skip = skip
|
| 17 |
-
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
| 18 |
-
|
| 19 |
-
def forward(self, x):
|
| 20 |
-
x = x.permute(2, 0, 1)
|
| 21 |
-
y, _ = self.lstm(x)
|
| 22 |
-
if self.skip:
|
| 23 |
-
y = y + x
|
| 24 |
-
y = y.permute(1, 2, 0)
|
| 25 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/rope.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import typing as tp
|
| 8 |
-
|
| 9 |
-
from torch import nn
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class XPos(nn.Module):
|
| 14 |
-
"""Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
|
| 15 |
-
This applies an exponential decay to the RoPE rotation matrix.
|
| 16 |
-
|
| 17 |
-
Args:
|
| 18 |
-
dim (int): Embedding dimension.
|
| 19 |
-
smoothing (float): Smoothing factor applied to the decay rates.
|
| 20 |
-
base_scale (int): Base decay rate, given in terms of scaling time.
|
| 21 |
-
device (torch.device, optional): Device on which to initialize the module.
|
| 22 |
-
dtype (torch.dtype): dtype to use to generate the embedding.
|
| 23 |
-
"""
|
| 24 |
-
def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
|
| 25 |
-
device=None, dtype: torch.dtype = torch.float32):
|
| 26 |
-
super().__init__()
|
| 27 |
-
assert dim % 2 == 0
|
| 28 |
-
assert dtype in [torch.float64, torch.float32]
|
| 29 |
-
self.dtype = dtype
|
| 30 |
-
self.base_scale = base_scale
|
| 31 |
-
|
| 32 |
-
half_dim = dim // 2
|
| 33 |
-
adim = torch.arange(half_dim, device=device, dtype=dtype)
|
| 34 |
-
decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
|
| 35 |
-
self.register_buffer("decay_rates", decay_rates)
|
| 36 |
-
self.decay: tp.Optional[torch.Tensor] = None
|
| 37 |
-
|
| 38 |
-
def get_decay(self, start: int, end: int):
|
| 39 |
-
"""Create complex decay tensor, cache values for fast computation."""
|
| 40 |
-
if self.decay is None or end > self.decay.shape[0]:
|
| 41 |
-
assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
|
| 42 |
-
idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
|
| 43 |
-
power = idx / self.base_scale
|
| 44 |
-
scale = self.decay_rates ** power.unsqueeze(-1)
|
| 45 |
-
self.decay = torch.polar(scale, torch.zeros_like(scale))
|
| 46 |
-
return self.decay[start:end] # [T, C/2]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class RotaryEmbedding(nn.Module):
|
| 50 |
-
"""Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
dim (int): Embedding dimension (twice the number of frequencies).
|
| 54 |
-
max_period (float): Maximum period of the rotation frequencies.
|
| 55 |
-
xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
|
| 56 |
-
scale (float): Scale of positional embedding, set to 0 to deactivate.
|
| 57 |
-
device (torch.device, optional): Device on which to initialize the module.
|
| 58 |
-
dtype (torch.dtype): dtype to use to generate the embedding.
|
| 59 |
-
"""
|
| 60 |
-
def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
|
| 61 |
-
scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
|
| 62 |
-
super().__init__()
|
| 63 |
-
assert dim % 2 == 0
|
| 64 |
-
self.scale = scale
|
| 65 |
-
assert dtype in [torch.float64, torch.float32]
|
| 66 |
-
self.dtype = dtype
|
| 67 |
-
|
| 68 |
-
adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
|
| 69 |
-
frequencies = 1.0 / (max_period ** (adim / dim))
|
| 70 |
-
self.register_buffer("frequencies", frequencies)
|
| 71 |
-
self.rotation: tp.Optional[torch.Tensor] = None
|
| 72 |
-
|
| 73 |
-
self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
|
| 74 |
-
|
| 75 |
-
def get_rotation(self, start: int, end: int):
|
| 76 |
-
"""Create complex rotation tensor, cache values for fast computation."""
|
| 77 |
-
if self.rotation is None or end > self.rotation.shape[0]:
|
| 78 |
-
assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
|
| 79 |
-
idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
|
| 80 |
-
angles = torch.outer(idx, self.frequencies)
|
| 81 |
-
self.rotation = torch.polar(torch.ones_like(angles), angles)
|
| 82 |
-
return self.rotation[start:end]
|
| 83 |
-
|
| 84 |
-
def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
|
| 85 |
-
"""Apply rope rotation to query or key tensor."""
|
| 86 |
-
T = x.shape[time_dim]
|
| 87 |
-
target_shape = [1] * x.dim()
|
| 88 |
-
target_shape[time_dim] = T
|
| 89 |
-
target_shape[-1] = -1
|
| 90 |
-
rotation = self.get_rotation(start, start + T).view(target_shape)
|
| 91 |
-
|
| 92 |
-
if self.xpos:
|
| 93 |
-
decay = self.xpos.get_decay(start, start + T).view(target_shape)
|
| 94 |
-
else:
|
| 95 |
-
decay = 1.0
|
| 96 |
-
|
| 97 |
-
if invert_decay:
|
| 98 |
-
decay = decay ** -1
|
| 99 |
-
|
| 100 |
-
x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
|
| 101 |
-
scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
|
| 102 |
-
x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
|
| 103 |
-
|
| 104 |
-
return x_out.type_as(x)
|
| 105 |
-
|
| 106 |
-
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
|
| 107 |
-
""" Apply rope rotation to both query and key tensors.
|
| 108 |
-
Supports streaming mode, in which query and key are not expected to have the same shape.
|
| 109 |
-
In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
|
| 110 |
-
query will be [C] (typically C == 1).
|
| 111 |
-
|
| 112 |
-
Args:
|
| 113 |
-
query (torch.Tensor): Query to rotate.
|
| 114 |
-
key (torch.Tensor): Key to rotate.
|
| 115 |
-
start (int): Start index of the sequence for time offset.
|
| 116 |
-
time_dim (int): which dimension represent the time steps.
|
| 117 |
-
"""
|
| 118 |
-
query_timesteps = query.shape[time_dim]
|
| 119 |
-
key_timesteps = key.shape[time_dim]
|
| 120 |
-
streaming_offset = key_timesteps - query_timesteps
|
| 121 |
-
|
| 122 |
-
query_out = self.rotate(query, start + streaming_offset, time_dim)
|
| 123 |
-
key_out = self.rotate(key, start, time_dim, invert_decay=True)
|
| 124 |
-
|
| 125 |
-
return query_out, key_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/seanet.py
DELETED
|
@@ -1,258 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import typing as tp
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
|
| 12 |
-
from .conv import StreamableConv1d, StreamableConvTranspose1d
|
| 13 |
-
from .lstm import StreamableLSTM
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class SEANetResnetBlock(nn.Module):
|
| 17 |
-
"""Residual block from SEANet model.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
dim (int): Dimension of the input/output.
|
| 21 |
-
kernel_sizes (list): List of kernel sizes for the convolutions.
|
| 22 |
-
dilations (list): List of dilations for the convolutions.
|
| 23 |
-
activation (str): Activation function.
|
| 24 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 25 |
-
norm (str): Normalization method.
|
| 26 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 27 |
-
causal (bool): Whether to use fully causal convolution.
|
| 28 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 29 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 30 |
-
true_skip (bool): Whether to use true skip connection or a simple
|
| 31 |
-
(streamable) convolution as the skip connection.
|
| 32 |
-
"""
|
| 33 |
-
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
|
| 34 |
-
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 35 |
-
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
|
| 36 |
-
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
|
| 37 |
-
super().__init__()
|
| 38 |
-
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
|
| 39 |
-
act = getattr(nn, activation)
|
| 40 |
-
hidden = dim // compress
|
| 41 |
-
block = []
|
| 42 |
-
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
| 43 |
-
in_chs = dim if i == 0 else hidden
|
| 44 |
-
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
| 45 |
-
block += [
|
| 46 |
-
act(**activation_params),
|
| 47 |
-
StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
|
| 48 |
-
norm=norm, norm_kwargs=norm_params,
|
| 49 |
-
causal=causal, pad_mode=pad_mode),
|
| 50 |
-
]
|
| 51 |
-
self.block = nn.Sequential(*block)
|
| 52 |
-
self.shortcut: nn.Module
|
| 53 |
-
if true_skip:
|
| 54 |
-
self.shortcut = nn.Identity()
|
| 55 |
-
else:
|
| 56 |
-
self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
|
| 57 |
-
causal=causal, pad_mode=pad_mode)
|
| 58 |
-
|
| 59 |
-
def forward(self, x):
|
| 60 |
-
return self.shortcut(x) + self.block(x)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class SEANetEncoder(nn.Module):
|
| 64 |
-
"""SEANet encoder.
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
channels (int): Audio channels.
|
| 68 |
-
dimension (int): Intermediate representation dimension.
|
| 69 |
-
n_filters (int): Base width for the model.
|
| 70 |
-
n_residual_layers (int): nb of residual layers.
|
| 71 |
-
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
| 72 |
-
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
| 73 |
-
that must match the decoder order. We use the decoder order as some models may only employ the decoder.
|
| 74 |
-
activation (str): Activation function.
|
| 75 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 76 |
-
norm (str): Normalization method.
|
| 77 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 78 |
-
kernel_size (int): Kernel size for the initial convolution.
|
| 79 |
-
last_kernel_size (int): Kernel size for the initial convolution.
|
| 80 |
-
residual_kernel_size (int): Kernel size for the residual layers.
|
| 81 |
-
dilation_base (int): How much to increase the dilation with each layer.
|
| 82 |
-
causal (bool): Whether to use fully causal convolution.
|
| 83 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 84 |
-
true_skip (bool): Whether to use true skip connection or a simple
|
| 85 |
-
(streamable) convolution as the skip connection in the residual network blocks.
|
| 86 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 87 |
-
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 88 |
-
disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
|
| 89 |
-
For the encoder, it corresponds to the N first blocks.
|
| 90 |
-
"""
|
| 91 |
-
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
|
| 92 |
-
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 93 |
-
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
|
| 94 |
-
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
|
| 95 |
-
pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
|
| 96 |
-
disable_norm_outer_blocks: int = 0):
|
| 97 |
-
super().__init__()
|
| 98 |
-
self.channels = channels
|
| 99 |
-
self.dimension = dimension
|
| 100 |
-
self.n_filters = n_filters
|
| 101 |
-
self.ratios = list(reversed(ratios))
|
| 102 |
-
del ratios
|
| 103 |
-
self.n_residual_layers = n_residual_layers
|
| 104 |
-
self.hop_length = np.prod(self.ratios)
|
| 105 |
-
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
|
| 106 |
-
self.disable_norm_outer_blocks = disable_norm_outer_blocks
|
| 107 |
-
assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
|
| 108 |
-
"Number of blocks for which to disable norm is invalid." \
|
| 109 |
-
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
|
| 110 |
-
|
| 111 |
-
act = getattr(nn, activation)
|
| 112 |
-
mult = 1
|
| 113 |
-
model: tp.List[nn.Module] = [
|
| 114 |
-
StreamableConv1d(channels, mult * n_filters, kernel_size,
|
| 115 |
-
norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
|
| 116 |
-
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
|
| 117 |
-
]
|
| 118 |
-
# Downsample to raw audio scale
|
| 119 |
-
for i, ratio in enumerate(self.ratios):
|
| 120 |
-
block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
|
| 121 |
-
# Add residual layers
|
| 122 |
-
for j in range(n_residual_layers):
|
| 123 |
-
model += [
|
| 124 |
-
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
|
| 125 |
-
dilations=[dilation_base ** j, 1],
|
| 126 |
-
norm=block_norm, norm_params=norm_params,
|
| 127 |
-
activation=activation, activation_params=activation_params,
|
| 128 |
-
causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
|
| 129 |
-
|
| 130 |
-
# Add downsampling layers
|
| 131 |
-
model += [
|
| 132 |
-
act(**activation_params),
|
| 133 |
-
StreamableConv1d(mult * n_filters, mult * n_filters * 2,
|
| 134 |
-
kernel_size=ratio * 2, stride=ratio,
|
| 135 |
-
norm=block_norm, norm_kwargs=norm_params,
|
| 136 |
-
causal=causal, pad_mode=pad_mode),
|
| 137 |
-
]
|
| 138 |
-
mult *= 2
|
| 139 |
-
|
| 140 |
-
if lstm:
|
| 141 |
-
model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
|
| 142 |
-
|
| 143 |
-
model += [
|
| 144 |
-
act(**activation_params),
|
| 145 |
-
StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
|
| 146 |
-
norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
|
| 147 |
-
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
|
| 148 |
-
]
|
| 149 |
-
|
| 150 |
-
self.model = nn.Sequential(*model)
|
| 151 |
-
|
| 152 |
-
def forward(self, x):
|
| 153 |
-
return self.model(x)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
class SEANetDecoder(nn.Module):
|
| 157 |
-
"""SEANet decoder.
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
channels (int): Audio channels.
|
| 161 |
-
dimension (int): Intermediate representation dimension.
|
| 162 |
-
n_filters (int): Base width for the model.
|
| 163 |
-
n_residual_layers (int): nb of residual layers.
|
| 164 |
-
ratios (Sequence[int]): kernel size and stride ratios.
|
| 165 |
-
activation (str): Activation function.
|
| 166 |
-
activation_params (dict): Parameters to provide to the activation function.
|
| 167 |
-
final_activation (str): Final activation function after all convolutions.
|
| 168 |
-
final_activation_params (dict): Parameters to provide to the activation function.
|
| 169 |
-
norm (str): Normalization method.
|
| 170 |
-
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 171 |
-
kernel_size (int): Kernel size for the initial convolution.
|
| 172 |
-
last_kernel_size (int): Kernel size for the initial convolution.
|
| 173 |
-
residual_kernel_size (int): Kernel size for the residual layers.
|
| 174 |
-
dilation_base (int): How much to increase the dilation with each layer.
|
| 175 |
-
causal (bool): Whether to use fully causal convolution.
|
| 176 |
-
pad_mode (str): Padding mode for the convolutions.
|
| 177 |
-
true_skip (bool): Whether to use true skip connection or a simple.
|
| 178 |
-
(streamable) convolution as the skip connection in the residual network blocks.
|
| 179 |
-
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 180 |
-
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 181 |
-
disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
|
| 182 |
-
For the decoder, it corresponds to the N last blocks.
|
| 183 |
-
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
|
| 184 |
-
If equal to 1.0, it means that all the trimming is done at the right.
|
| 185 |
-
"""
|
| 186 |
-
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
|
| 187 |
-
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 188 |
-
final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
|
| 189 |
-
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
|
| 190 |
-
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
|
| 191 |
-
pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
|
| 192 |
-
disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
|
| 193 |
-
super().__init__()
|
| 194 |
-
self.dimension = dimension
|
| 195 |
-
self.channels = channels
|
| 196 |
-
self.n_filters = n_filters
|
| 197 |
-
self.ratios = ratios
|
| 198 |
-
del ratios
|
| 199 |
-
self.n_residual_layers = n_residual_layers
|
| 200 |
-
self.hop_length = np.prod(self.ratios)
|
| 201 |
-
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
|
| 202 |
-
self.disable_norm_outer_blocks = disable_norm_outer_blocks
|
| 203 |
-
assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
|
| 204 |
-
"Number of blocks for which to disable norm is invalid." \
|
| 205 |
-
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
|
| 206 |
-
|
| 207 |
-
act = getattr(nn, activation)
|
| 208 |
-
mult = int(2 ** len(self.ratios))
|
| 209 |
-
model: tp.List[nn.Module] = [
|
| 210 |
-
StreamableConv1d(dimension, mult * n_filters, kernel_size,
|
| 211 |
-
norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
|
| 212 |
-
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
|
| 213 |
-
]
|
| 214 |
-
|
| 215 |
-
if lstm:
|
| 216 |
-
model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
|
| 217 |
-
|
| 218 |
-
# Upsample to raw audio scale
|
| 219 |
-
for i, ratio in enumerate(self.ratios):
|
| 220 |
-
block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
|
| 221 |
-
# Add upsampling layers
|
| 222 |
-
model += [
|
| 223 |
-
act(**activation_params),
|
| 224 |
-
StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
|
| 225 |
-
kernel_size=ratio * 2, stride=ratio,
|
| 226 |
-
norm=block_norm, norm_kwargs=norm_params,
|
| 227 |
-
causal=causal, trim_right_ratio=trim_right_ratio),
|
| 228 |
-
]
|
| 229 |
-
# Add residual layers
|
| 230 |
-
for j in range(n_residual_layers):
|
| 231 |
-
model += [
|
| 232 |
-
SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
|
| 233 |
-
dilations=[dilation_base ** j, 1],
|
| 234 |
-
activation=activation, activation_params=activation_params,
|
| 235 |
-
norm=block_norm, norm_params=norm_params, causal=causal,
|
| 236 |
-
pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
|
| 237 |
-
|
| 238 |
-
mult //= 2
|
| 239 |
-
|
| 240 |
-
# Add final layers
|
| 241 |
-
model += [
|
| 242 |
-
act(**activation_params),
|
| 243 |
-
StreamableConv1d(n_filters, channels, last_kernel_size,
|
| 244 |
-
norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
|
| 245 |
-
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
|
| 246 |
-
]
|
| 247 |
-
# Add optional final activation to decoder (eg. tanh)
|
| 248 |
-
if final_activation is not None:
|
| 249 |
-
final_act = getattr(nn, final_activation)
|
| 250 |
-
final_activation_params = final_activation_params or {}
|
| 251 |
-
model += [
|
| 252 |
-
final_act(**final_activation_params)
|
| 253 |
-
]
|
| 254 |
-
self.model = nn.Sequential(*model)
|
| 255 |
-
|
| 256 |
-
def forward(self, z):
|
| 257 |
-
y = self.model(z)
|
| 258 |
-
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/streaming.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Streaming module API that should be implemented by all Streaming components,
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from contextlib import contextmanager
|
| 12 |
-
import typing as tp
|
| 13 |
-
from torch import nn
|
| 14 |
-
import torch
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
State = tp.Dict[str, torch.Tensor]
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class StreamingModule(nn.Module):
|
| 21 |
-
"""Common API for streaming components.
|
| 22 |
-
|
| 23 |
-
Each streaming component has a streaming state, which is just a dict[str, Tensor].
|
| 24 |
-
By convention, the first dim of each tensor must be the batch size.
|
| 25 |
-
Don't use dots in the key names, as this would clash with submodules
|
| 26 |
-
(like in state_dict).
|
| 27 |
-
|
| 28 |
-
If `self._is_streaming` is True, the component should use and remember
|
| 29 |
-
the proper state inside `self._streaming_state`.
|
| 30 |
-
|
| 31 |
-
To set a streaming component in streaming state, use
|
| 32 |
-
|
| 33 |
-
with module.streaming():
|
| 34 |
-
...
|
| 35 |
-
|
| 36 |
-
This will automatically reset the streaming state when exiting the context manager.
|
| 37 |
-
This also automatically propagates to all streaming children module.
|
| 38 |
-
|
| 39 |
-
Some module might also implement the `StreamingModule.flush` method, although
|
| 40 |
-
this one is trickier, as all parents module must be StreamingModule and implement
|
| 41 |
-
it as well for it to work properly. See `StreamingSequential` after.
|
| 42 |
-
"""
|
| 43 |
-
def __init__(self) -> None:
|
| 44 |
-
super().__init__()
|
| 45 |
-
self._streaming_state: State = {}
|
| 46 |
-
self._is_streaming = False
|
| 47 |
-
|
| 48 |
-
def _apply_named_streaming(self, fn: tp.Any):
|
| 49 |
-
for name, module in self.named_modules():
|
| 50 |
-
if isinstance(module, StreamingModule):
|
| 51 |
-
fn(name, module)
|
| 52 |
-
|
| 53 |
-
def _set_streaming(self, streaming: bool):
|
| 54 |
-
def _set_streaming(name, module):
|
| 55 |
-
module._is_streaming = streaming
|
| 56 |
-
self._apply_named_streaming(_set_streaming)
|
| 57 |
-
|
| 58 |
-
@contextmanager
|
| 59 |
-
def streaming(self):
|
| 60 |
-
"""Context manager to enter streaming mode. Reset streaming state on exit.
|
| 61 |
-
"""
|
| 62 |
-
self._set_streaming(True)
|
| 63 |
-
try:
|
| 64 |
-
yield
|
| 65 |
-
finally:
|
| 66 |
-
self._set_streaming(False)
|
| 67 |
-
self.reset_streaming()
|
| 68 |
-
|
| 69 |
-
def reset_streaming(self):
|
| 70 |
-
"""Reset the streaming state.
|
| 71 |
-
"""
|
| 72 |
-
def _reset(name: str, module: StreamingModule):
|
| 73 |
-
module._streaming_state.clear()
|
| 74 |
-
|
| 75 |
-
self._apply_named_streaming(_reset)
|
| 76 |
-
|
| 77 |
-
def get_streaming_state(self) -> State:
|
| 78 |
-
"""Return the streaming state, including that of sub-modules.
|
| 79 |
-
"""
|
| 80 |
-
state: State = {}
|
| 81 |
-
|
| 82 |
-
def _add(name: str, module: StreamingModule):
|
| 83 |
-
if name:
|
| 84 |
-
name += "."
|
| 85 |
-
for key, value in module._streaming_state.items():
|
| 86 |
-
state[name + key] = value
|
| 87 |
-
|
| 88 |
-
self._apply_named_streaming(_add)
|
| 89 |
-
return state
|
| 90 |
-
|
| 91 |
-
def set_streaming_state(self, state: State):
|
| 92 |
-
"""Set the streaming state, including that of sub-modules.
|
| 93 |
-
"""
|
| 94 |
-
state = dict(state)
|
| 95 |
-
|
| 96 |
-
def _set(name: str, module: StreamingModule):
|
| 97 |
-
if name:
|
| 98 |
-
name += "."
|
| 99 |
-
module._streaming_state.clear()
|
| 100 |
-
for key, value in list(state.items()):
|
| 101 |
-
# complexity is not ideal here, but probably fine.
|
| 102 |
-
if key.startswith(name):
|
| 103 |
-
local_key = key[len(name):]
|
| 104 |
-
if '.' not in local_key:
|
| 105 |
-
module._streaming_state[local_key] = value
|
| 106 |
-
del state[key]
|
| 107 |
-
|
| 108 |
-
self._apply_named_streaming(_set)
|
| 109 |
-
assert len(state) == 0, list(state.keys())
|
| 110 |
-
|
| 111 |
-
def flush(self, x: tp.Optional[torch.Tensor] = None):
|
| 112 |
-
"""Flush any remaining outputs that were waiting for completion.
|
| 113 |
-
Typically, for convolutions, this will add the final padding
|
| 114 |
-
and process the last buffer.
|
| 115 |
-
|
| 116 |
-
This should take an optional argument `x`, which will be provided
|
| 117 |
-
if a module before this one in the streaming pipeline has already
|
| 118 |
-
spitted out a flushed out buffer.
|
| 119 |
-
"""
|
| 120 |
-
if x is None:
|
| 121 |
-
return None
|
| 122 |
-
else:
|
| 123 |
-
return self(x)
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
class StreamingSequential(StreamingModule, nn.Sequential):
|
| 127 |
-
"""A streaming compatible alternative of `nn.Sequential`.
|
| 128 |
-
"""
|
| 129 |
-
def flush(self, x: tp.Optional[torch.Tensor] = None):
|
| 130 |
-
for module in self:
|
| 131 |
-
if isinstance(module, StreamingModule):
|
| 132 |
-
x = module.flush(x)
|
| 133 |
-
elif x is not None:
|
| 134 |
-
x = module(x)
|
| 135 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/transformer.py
DELETED
|
@@ -1,755 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Transformer model, with streaming support, xformer attention support
|
| 9 |
-
and easy causal attention with a potentially finite receptive field.
|
| 10 |
-
|
| 11 |
-
See `StreamingTransformer` for more information.
|
| 12 |
-
|
| 13 |
-
Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
import typing as tp
|
| 17 |
-
|
| 18 |
-
from einops import rearrange
|
| 19 |
-
import torch
|
| 20 |
-
import torch.nn as nn
|
| 21 |
-
from torch.nn import functional as F
|
| 22 |
-
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
| 23 |
-
from xformers import ops
|
| 24 |
-
|
| 25 |
-
from .rope import RotaryEmbedding
|
| 26 |
-
from .streaming import StreamingModule
|
| 27 |
-
|
| 28 |
-
_efficient_attention_backend: str = 'torch'
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def set_efficient_attention_backend(backend: str = 'torch'):
|
| 32 |
-
# Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
|
| 33 |
-
global _efficient_attention_backend
|
| 34 |
-
assert _efficient_attention_backend in ['xformers', 'torch']
|
| 35 |
-
_efficient_attention_backend = backend
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def _get_attention_time_dimension(memory_efficient: bool) -> int:
|
| 39 |
-
if _efficient_attention_backend == 'torch' and memory_efficient:
|
| 40 |
-
return 2
|
| 41 |
-
else:
|
| 42 |
-
return 1
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _is_profiled() -> bool:
|
| 46 |
-
# Return true if we are currently running with a xformers profiler activated.
|
| 47 |
-
try:
|
| 48 |
-
from xformers.profiler import profiler
|
| 49 |
-
except ImportError:
|
| 50 |
-
return False
|
| 51 |
-
return profiler._Profiler._CURRENT_PROFILER is not None
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
|
| 55 |
-
"""Create normalization module for transformer encoder layer.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
norm_type (str): Normalization method.
|
| 59 |
-
dim (int): Dimension of the normalized layer.
|
| 60 |
-
**kwargs (dict): Additional parameters for normalization layer.
|
| 61 |
-
Returns:
|
| 62 |
-
nn.Module: Normalization module.
|
| 63 |
-
"""
|
| 64 |
-
if norm_type == 'layer_norm':
|
| 65 |
-
return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
| 66 |
-
else:
|
| 67 |
-
raise ValueError(f"Unknown norm type: {norm_type}")
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
|
| 71 |
-
dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
| 72 |
-
"""Create sinusoidal positional embedding, with shape `[B, T, C]`.
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
positions (torch.Tensor): LongTensor of positions.
|
| 76 |
-
dim (int): Dimension of the embedding.
|
| 77 |
-
max_period (float): Maximum period of the cosine/sine functions.
|
| 78 |
-
dtype (torch.dtype or str): dtype to use to generate the embedding.
|
| 79 |
-
Returns:
|
| 80 |
-
torch.Tensor: Sinusoidal positional embedding.
|
| 81 |
-
"""
|
| 82 |
-
# We aim for BTC format
|
| 83 |
-
assert dim % 2 == 0
|
| 84 |
-
half_dim = dim // 2
|
| 85 |
-
positions = positions.to(dtype)
|
| 86 |
-
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
|
| 87 |
-
max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
|
| 88 |
-
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
|
| 89 |
-
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
|
| 93 |
-
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
|
| 94 |
-
if n_rep == 1:
|
| 95 |
-
return x
|
| 96 |
-
if _efficient_attention_backend == 'torch' and memory_efficient:
|
| 97 |
-
bs, n_kv_heads, slen, head_dim = x.shape
|
| 98 |
-
return (
|
| 99 |
-
x[:, :, None, :, :]
|
| 100 |
-
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
| 101 |
-
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
| 102 |
-
)
|
| 103 |
-
else:
|
| 104 |
-
bs, slen, n_kv_heads, head_dim = x.shape
|
| 105 |
-
return (
|
| 106 |
-
x[:, :, :, None, :]
|
| 107 |
-
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
| 108 |
-
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class LayerScale(nn.Module):
|
| 113 |
-
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
| 114 |
-
This rescales diagonally the residual outputs close to 0, with a learnt scale.
|
| 115 |
-
|
| 116 |
-
Args:
|
| 117 |
-
channels (int): Number of channels.
|
| 118 |
-
init (float): Initial scale.
|
| 119 |
-
channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
|
| 120 |
-
device (torch.device or str, optional): Device on which to initialize the module.
|
| 121 |
-
dtype (torch.dtype, optional): dtype to use to initialize the module.
|
| 122 |
-
"""
|
| 123 |
-
def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
|
| 124 |
-
device=None, dtype=None):
|
| 125 |
-
super().__init__()
|
| 126 |
-
self.channel_last = channel_last
|
| 127 |
-
self.scale = nn.Parameter(
|
| 128 |
-
torch.full((channels,), init,
|
| 129 |
-
requires_grad=True, device=device, dtype=dtype))
|
| 130 |
-
|
| 131 |
-
def forward(self, x: torch.Tensor):
|
| 132 |
-
if self.channel_last:
|
| 133 |
-
return self.scale * x
|
| 134 |
-
else:
|
| 135 |
-
return self.scale[:, None] * x
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class StreamingMultiheadAttention(StreamingModule):
|
| 139 |
-
"""Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
|
| 140 |
-
|
| 141 |
-
Args:
|
| 142 |
-
embed_dim (int): Dimension to project to.
|
| 143 |
-
num_heads (int): Number of heads.
|
| 144 |
-
dropout (float): Dropout level.
|
| 145 |
-
bias (bool): Use bias in projections.
|
| 146 |
-
causal (bool): Causal mask applied automatically.
|
| 147 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
| 148 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
| 149 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
| 150 |
-
attention_as_float32 (bool): Perform the attention as float32
|
| 151 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
| 152 |
-
rope (`RotaryEmbedding`, optional): Rope embedding to use.
|
| 153 |
-
cross_attention: Should be true when used as a cross attention.
|
| 154 |
-
All keys and values must be available at once, streaming is only for the queries.
|
| 155 |
-
Cannot be used with `causal` or `rope` (as it wouldn't make sens to
|
| 156 |
-
interpret the time steps in the keys relative to those in the queries).
|
| 157 |
-
safe_streaming (bool): Bug fix, will go away with xformers update.
|
| 158 |
-
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
|
| 159 |
-
kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
|
| 160 |
-
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
| 161 |
-
device (torch.device, optional): Device on which to initialize.
|
| 162 |
-
dtype (torch.dtype, optional): dtype to use.
|
| 163 |
-
"""
|
| 164 |
-
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
|
| 165 |
-
causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
|
| 166 |
-
memory_efficient: bool = False, attention_as_float32: bool = False,
|
| 167 |
-
rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
|
| 168 |
-
safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
|
| 169 |
-
device=None, dtype=None):
|
| 170 |
-
super().__init__()
|
| 171 |
-
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 172 |
-
if past_context is not None:
|
| 173 |
-
assert causal
|
| 174 |
-
|
| 175 |
-
self.embed_dim = embed_dim
|
| 176 |
-
self.causal = causal
|
| 177 |
-
self.past_context = past_context
|
| 178 |
-
self.memory_efficient = memory_efficient
|
| 179 |
-
self.attention_as_float32 = attention_as_float32
|
| 180 |
-
self.rope = rope
|
| 181 |
-
self.cross_attention = cross_attention
|
| 182 |
-
self.safe_streaming = safe_streaming
|
| 183 |
-
self.num_heads = num_heads
|
| 184 |
-
self.dropout = dropout
|
| 185 |
-
self.kv_repeat = kv_repeat
|
| 186 |
-
if cross_attention:
|
| 187 |
-
assert not causal, "Causal cannot work with cross attention."
|
| 188 |
-
assert rope is None, "Rope cannot work with cross attention."
|
| 189 |
-
|
| 190 |
-
if memory_efficient:
|
| 191 |
-
_verify_xformers_memory_efficient_compat()
|
| 192 |
-
|
| 193 |
-
self.custom = _is_custom(custom, memory_efficient)
|
| 194 |
-
if self.custom:
|
| 195 |
-
out_dim = embed_dim
|
| 196 |
-
assert num_heads % kv_repeat == 0
|
| 197 |
-
assert not cross_attention or kv_repeat == 1
|
| 198 |
-
num_kv = num_heads // kv_repeat
|
| 199 |
-
kv_dim = (embed_dim // num_heads) * num_kv
|
| 200 |
-
out_dim += 2 * kv_dim
|
| 201 |
-
in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
|
| 202 |
-
# We try to follow the default PyTorch MHA convention, to easily compare results.
|
| 203 |
-
self.in_proj_weight = in_proj.weight
|
| 204 |
-
self.in_proj_bias = in_proj.bias
|
| 205 |
-
if bias:
|
| 206 |
-
self.in_proj_bias.data.zero_() # Following Pytorch convention
|
| 207 |
-
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
| 208 |
-
if bias:
|
| 209 |
-
self.out_proj.bias.data.zero_()
|
| 210 |
-
else:
|
| 211 |
-
assert not qk_layer_norm
|
| 212 |
-
assert kv_repeat == 1
|
| 213 |
-
self.mha = nn.MultiheadAttention(
|
| 214 |
-
embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
|
| 215 |
-
**factory_kwargs)
|
| 216 |
-
self.qk_layer_norm = qk_layer_norm
|
| 217 |
-
if qk_layer_norm:
|
| 218 |
-
assert self.custom
|
| 219 |
-
assert kv_repeat == 1
|
| 220 |
-
ln_dim = embed_dim
|
| 221 |
-
self.q_layer_norm = nn.LayerNorm(ln_dim)
|
| 222 |
-
self.k_layer_norm = nn.LayerNorm(ln_dim)
|
| 223 |
-
|
| 224 |
-
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
| 225 |
-
if not self.custom:
|
| 226 |
-
# Support compat with regular MHA
|
| 227 |
-
keys = [n for n, _ in self.mha.named_parameters()]
|
| 228 |
-
for key in keys:
|
| 229 |
-
if prefix + key in state_dict:
|
| 230 |
-
state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
|
| 231 |
-
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
| 232 |
-
|
| 233 |
-
def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
|
| 234 |
-
# Return a causal mask, accounting for potentially stored past keys/values
|
| 235 |
-
# We actually return a bias for the attention score, as this has the same
|
| 236 |
-
# convention both in the builtin MHA in Pytorch, and Xformers functions.
|
| 237 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
| 238 |
-
if self.memory_efficient:
|
| 239 |
-
from xformers.ops import LowerTriangularMask
|
| 240 |
-
if current_steps == 1:
|
| 241 |
-
# If we only have one step, then we do not need a mask.
|
| 242 |
-
return None
|
| 243 |
-
elif 'past_keys' in self._streaming_state:
|
| 244 |
-
raise RuntimeError("Not supported at the moment")
|
| 245 |
-
else:
|
| 246 |
-
# Then we can safely use a lower triangular mask
|
| 247 |
-
return LowerTriangularMask()
|
| 248 |
-
if self._streaming_state:
|
| 249 |
-
past_keys = self._streaming_state['past_keys']
|
| 250 |
-
past_steps = past_keys.shape[time_dim]
|
| 251 |
-
else:
|
| 252 |
-
past_steps = 0
|
| 253 |
-
|
| 254 |
-
queries_pos = torch.arange(
|
| 255 |
-
past_steps, current_steps + past_steps, device=device).view(-1, 1)
|
| 256 |
-
keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
|
| 257 |
-
delta = queries_pos - keys_pos
|
| 258 |
-
valid = delta >= 0
|
| 259 |
-
if self.past_context is not None:
|
| 260 |
-
valid &= (delta <= self.past_context)
|
| 261 |
-
return torch.where(
|
| 262 |
-
valid,
|
| 263 |
-
torch.zeros([], device=device, dtype=dtype),
|
| 264 |
-
torch.full([], float('-inf'), device=device, dtype=dtype))
|
| 265 |
-
|
| 266 |
-
def _complete_kv(self, k, v):
|
| 267 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
| 268 |
-
if self.cross_attention:
|
| 269 |
-
# With cross attention we assume all keys and values
|
| 270 |
-
# are already available, and streaming is with respect
|
| 271 |
-
# to the queries only.
|
| 272 |
-
return k, v
|
| 273 |
-
# Complete the key/value pair using the streaming state.
|
| 274 |
-
if self._streaming_state:
|
| 275 |
-
pk = self._streaming_state['past_keys']
|
| 276 |
-
nk = torch.cat([pk, k], dim=time_dim)
|
| 277 |
-
if v is k:
|
| 278 |
-
nv = nk
|
| 279 |
-
else:
|
| 280 |
-
pv = self._streaming_state['past_values']
|
| 281 |
-
nv = torch.cat([pv, v], dim=time_dim)
|
| 282 |
-
else:
|
| 283 |
-
nk = k
|
| 284 |
-
nv = v
|
| 285 |
-
|
| 286 |
-
assert nk.shape[time_dim] == nv.shape[time_dim]
|
| 287 |
-
offset = 0
|
| 288 |
-
if self.past_context is not None:
|
| 289 |
-
offset = max(0, nk.shape[time_dim] - self.past_context)
|
| 290 |
-
if self._is_streaming:
|
| 291 |
-
self._streaming_state['past_keys'] = nk[:, offset:]
|
| 292 |
-
if v is not k:
|
| 293 |
-
self._streaming_state['past_values'] = nv[:, offset:]
|
| 294 |
-
if 'offset' in self._streaming_state:
|
| 295 |
-
self._streaming_state['offset'] += offset
|
| 296 |
-
else:
|
| 297 |
-
self._streaming_state['offset'] = torch.tensor(0)
|
| 298 |
-
return nk, nv
|
| 299 |
-
|
| 300 |
-
def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
|
| 301 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
| 302 |
-
# Apply rope embeddings to query and key tensors.
|
| 303 |
-
assert self.rope is not None
|
| 304 |
-
if 'past_keys' in self._streaming_state:
|
| 305 |
-
past_keys_offset = self._streaming_state['past_keys'].shape[1]
|
| 306 |
-
else:
|
| 307 |
-
past_keys_offset = 0
|
| 308 |
-
if 'offset' in self._streaming_state:
|
| 309 |
-
past_context_offset = int(self._streaming_state['offset'].item())
|
| 310 |
-
else:
|
| 311 |
-
past_context_offset = 0
|
| 312 |
-
streaming_offset = past_context_offset + past_keys_offset
|
| 313 |
-
return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
|
| 314 |
-
|
| 315 |
-
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
| 316 |
-
key_padding_mask=None, need_weights=False, attn_mask=None,
|
| 317 |
-
average_attn_weights=True, is_causal=False):
|
| 318 |
-
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
|
| 319 |
-
"use the causal args in the constructor.")
|
| 320 |
-
|
| 321 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
| 322 |
-
if time_dim == 2:
|
| 323 |
-
layout = "b h t d"
|
| 324 |
-
else:
|
| 325 |
-
layout = "b t h d"
|
| 326 |
-
dtype = query.dtype
|
| 327 |
-
if self._is_streaming:
|
| 328 |
-
assert self.causal or self.cross_attention, \
|
| 329 |
-
"Streaming only available for causal or cross attention"
|
| 330 |
-
|
| 331 |
-
custom_attn_mask = attn_mask is not None
|
| 332 |
-
|
| 333 |
-
if self.causal:
|
| 334 |
-
assert attn_mask is None
|
| 335 |
-
# At the moment we specialize only for the self-attention case.
|
| 336 |
-
assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
| 337 |
-
assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
| 338 |
-
attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
|
| 339 |
-
|
| 340 |
-
if self.custom:
|
| 341 |
-
# custom implementation
|
| 342 |
-
assert need_weights is False
|
| 343 |
-
assert key_padding_mask is None
|
| 344 |
-
if self.cross_attention:
|
| 345 |
-
# Different queries, keys, values, we have to spit manually the weights
|
| 346 |
-
# before applying the linear.
|
| 347 |
-
dim = self.in_proj_weight.shape[0] // 3
|
| 348 |
-
if self.in_proj_bias is None:
|
| 349 |
-
bias_q, bias_k, bias_v = None, None, None
|
| 350 |
-
else:
|
| 351 |
-
bias_q = self.in_proj_bias[:dim]
|
| 352 |
-
bias_k = self.in_proj_bias[dim: 2 * dim]
|
| 353 |
-
bias_v = self.in_proj_bias[2 * dim:]
|
| 354 |
-
q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
|
| 355 |
-
# todo: when streaming, we could actually save k, v and check the shape actually match.
|
| 356 |
-
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
|
| 357 |
-
v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
|
| 358 |
-
if self.qk_layer_norm is True:
|
| 359 |
-
q = self.q_layer_norm(q)
|
| 360 |
-
k = self.k_layer_norm(k)
|
| 361 |
-
q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
|
| 362 |
-
else:
|
| 363 |
-
if not _is_profiled():
|
| 364 |
-
# profiling breaks that propertysomehow.
|
| 365 |
-
assert query is key, "specialized implementation"
|
| 366 |
-
assert value is key, "specialized implementation"
|
| 367 |
-
projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
|
| 368 |
-
if self.kv_repeat == 1:
|
| 369 |
-
if time_dim == 2:
|
| 370 |
-
bound_layout = "b h p t d"
|
| 371 |
-
else:
|
| 372 |
-
bound_layout = "b t p h d"
|
| 373 |
-
packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
|
| 374 |
-
q, k, v = ops.unbind(packed, dim=2)
|
| 375 |
-
else:
|
| 376 |
-
embed_dim = self.embed_dim
|
| 377 |
-
per_head_dim = (embed_dim // self.num_heads)
|
| 378 |
-
kv_heads = self.num_heads // self.kv_repeat
|
| 379 |
-
q = projected[:, :, :embed_dim]
|
| 380 |
-
start = embed_dim
|
| 381 |
-
end = start + per_head_dim * kv_heads
|
| 382 |
-
k = projected[:, :, start: end]
|
| 383 |
-
v = projected[:, :, end:]
|
| 384 |
-
q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
|
| 385 |
-
k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
|
| 386 |
-
v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
|
| 387 |
-
|
| 388 |
-
if self.qk_layer_norm is True:
|
| 389 |
-
assert self.kv_repeat == 1
|
| 390 |
-
q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
|
| 391 |
-
q = self.q_layer_norm(q)
|
| 392 |
-
k = self.k_layer_norm(k)
|
| 393 |
-
q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
|
| 394 |
-
if self.rope:
|
| 395 |
-
q, k = self._apply_rope(q, k)
|
| 396 |
-
k, v = self._complete_kv(k, v)
|
| 397 |
-
if self.kv_repeat > 1:
|
| 398 |
-
k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
|
| 399 |
-
v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
|
| 400 |
-
if self.attention_as_float32:
|
| 401 |
-
q, k, v = [x.float() for x in [q, k, v]]
|
| 402 |
-
if self.memory_efficient:
|
| 403 |
-
if custom_attn_mask:
|
| 404 |
-
# When using a custom attn mask:
|
| 405 |
-
# Move to query's device, repeat for each sample, remove align8 padding
|
| 406 |
-
seq_len = query.shape[1]
|
| 407 |
-
attn_mask = attn_mask.to(q.dtype)
|
| 408 |
-
attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
|
| 409 |
-
attn_mask = attn_mask[..., :seq_len, :seq_len]
|
| 410 |
-
|
| 411 |
-
p = self.dropout if self.training else 0
|
| 412 |
-
if _efficient_attention_backend == 'torch':
|
| 413 |
-
x = torch.nn.functional.scaled_dot_product_attention(
|
| 414 |
-
q, k, v, is_causal=attn_mask is not None, dropout_p=p)
|
| 415 |
-
else:
|
| 416 |
-
x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
|
| 417 |
-
else:
|
| 418 |
-
# We include the dot product as float32, for consistency
|
| 419 |
-
# with the other implementations that include that step
|
| 420 |
-
# as part of the attention. Note that when using `autocast`,
|
| 421 |
-
# the einsums would be done as bfloat16, but the softmax
|
| 422 |
-
# would be done as bfloat16, so `attention_as_float32` will
|
| 423 |
-
# extend a bit the range of operations done in float32,
|
| 424 |
-
# although this should make no difference.
|
| 425 |
-
q = q / q.shape[-1] ** 0.5
|
| 426 |
-
key_layout = layout.replace('t', 'k')
|
| 427 |
-
query_layout = layout
|
| 428 |
-
if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
|
| 429 |
-
with torch.autocast(device_type=q.device.type, dtype=torch.float32):
|
| 430 |
-
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
| 431 |
-
else:
|
| 432 |
-
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
| 433 |
-
if attn_mask is not None:
|
| 434 |
-
pre_w = pre_w + attn_mask
|
| 435 |
-
w = torch.softmax(pre_w, dim=-1)
|
| 436 |
-
w = F.dropout(w, self.dropout, training=self.training).to(v)
|
| 437 |
-
# Key and value have the same format.
|
| 438 |
-
x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
|
| 439 |
-
x = x.to(dtype)
|
| 440 |
-
x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
|
| 441 |
-
x = self.out_proj(x)
|
| 442 |
-
else:
|
| 443 |
-
key, value = self._complete_kv(key, value)
|
| 444 |
-
if self.attention_as_float32:
|
| 445 |
-
query, key, value = [x.float() for x in [query, key, value]]
|
| 446 |
-
x, _ = self.mha(
|
| 447 |
-
query, key, value, key_padding_mask,
|
| 448 |
-
need_weights, attn_mask, average_attn_weights)
|
| 449 |
-
x = x.to(dtype)
|
| 450 |
-
|
| 451 |
-
return x, None
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
| 455 |
-
"""TransformerLayer with Streaming / Causal support.
|
| 456 |
-
This also integrates cross_attention, when passing `cross_attention=True`,
|
| 457 |
-
rather than having two separate classes like in PyTorch.
|
| 458 |
-
|
| 459 |
-
Args:
|
| 460 |
-
d_model (int): Dimension of the data.
|
| 461 |
-
num_heads (int): Number of heads.
|
| 462 |
-
dim_feedforward (int): Intermediate dimension of FF module.
|
| 463 |
-
dropout (float): Dropout both for MHA and FF.
|
| 464 |
-
bias_ff (bool): Use bias for FF.
|
| 465 |
-
bias_attn (bool): Use bias for MHA.
|
| 466 |
-
causal (bool): Causal mask applied automatically.
|
| 467 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
| 468 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
| 469 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
| 470 |
-
attention_as_float32 (bool): Perform the attention as float32
|
| 471 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
| 472 |
-
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
|
| 473 |
-
qk_layer_norm_cross (bool): Same for the cross attention.
|
| 474 |
-
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
| 475 |
-
Cross attention will use the default MHA, as it typically won't require
|
| 476 |
-
special treatment.
|
| 477 |
-
layer_scale (float, optional): If not None, LayerScale will be used with
|
| 478 |
-
the given value as initial scale.
|
| 479 |
-
rope (`RotaryEmbedding`, optional): Rope embedding to use.
|
| 480 |
-
attention_dropout (float, optional): If not None, separate the value of the dimension dropout
|
| 481 |
-
in FFN and of the attention dropout.
|
| 482 |
-
kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
|
| 483 |
-
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
| 484 |
-
device (torch.device, optional): Device on which to initialize.
|
| 485 |
-
dtype (torch.dtype, optional): dtype to use.
|
| 486 |
-
**kwargs: See `nn.TransformerEncoderLayer`.
|
| 487 |
-
"""
|
| 488 |
-
def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
| 489 |
-
bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
|
| 490 |
-
past_context: tp.Optional[int] = None, custom: bool = False,
|
| 491 |
-
memory_efficient: bool = False, attention_as_float32: bool = False,
|
| 492 |
-
qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
|
| 493 |
-
cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
|
| 494 |
-
rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
|
| 495 |
-
kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
|
| 496 |
-
super().__init__(d_model, num_heads, dim_feedforward, dropout,
|
| 497 |
-
device=device, dtype=dtype, batch_first=True, **kwargs)
|
| 498 |
-
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 499 |
-
# Redefine self_attn to our streaming multi-head attention
|
| 500 |
-
attn_kwargs: tp.Dict[str, tp.Any] = {
|
| 501 |
-
'embed_dim': d_model,
|
| 502 |
-
'num_heads': num_heads,
|
| 503 |
-
'dropout': dropout if attention_dropout is None else attention_dropout,
|
| 504 |
-
'bias': bias_attn,
|
| 505 |
-
'custom': custom,
|
| 506 |
-
'memory_efficient': memory_efficient,
|
| 507 |
-
'attention_as_float32': attention_as_float32,
|
| 508 |
-
}
|
| 509 |
-
self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
|
| 510 |
-
causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
|
| 511 |
-
kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
|
| 512 |
-
# Redefine feedforward layers to expose bias parameter
|
| 513 |
-
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
|
| 514 |
-
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
|
| 515 |
-
|
| 516 |
-
self.layer_scale_1: nn.Module
|
| 517 |
-
self.layer_scale_2: nn.Module
|
| 518 |
-
if layer_scale is None:
|
| 519 |
-
self.layer_scale_1 = nn.Identity()
|
| 520 |
-
self.layer_scale_2 = nn.Identity()
|
| 521 |
-
else:
|
| 522 |
-
self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
| 523 |
-
self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
| 524 |
-
|
| 525 |
-
self.cross_attention: tp.Optional[nn.Module] = None
|
| 526 |
-
if cross_attention:
|
| 527 |
-
self.cross_attention = StreamingMultiheadAttention(
|
| 528 |
-
cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
|
| 529 |
-
**attn_kwargs, **factory_kwargs)
|
| 530 |
-
# Norm and dropout
|
| 531 |
-
self.dropout_cross = nn.Dropout(dropout)
|
| 532 |
-
# eps value matching that used in PyTorch reference implementation.
|
| 533 |
-
self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
|
| 534 |
-
self.layer_scale_cross: nn.Module
|
| 535 |
-
if layer_scale is None:
|
| 536 |
-
self.layer_scale_cross = nn.Identity()
|
| 537 |
-
else:
|
| 538 |
-
self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
|
| 539 |
-
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
| 540 |
-
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
| 541 |
-
|
| 542 |
-
def _cross_attention_block(self, src: torch.Tensor,
|
| 543 |
-
cross_attention_src: torch.Tensor) -> torch.Tensor:
|
| 544 |
-
assert self.cross_attention is not None
|
| 545 |
-
# queries are from src, keys and values from cross_attention_src.
|
| 546 |
-
x = self.cross_attention(
|
| 547 |
-
src, cross_attention_src, cross_attention_src, need_weights=False)[0]
|
| 548 |
-
return self.dropout_cross(x) # type: ignore
|
| 549 |
-
|
| 550 |
-
def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
|
| 551 |
-
src_key_padding_mask: tp.Optional[torch.Tensor] = None,
|
| 552 |
-
cross_attention_src: tp.Optional[torch.Tensor] = None):
|
| 553 |
-
if self.cross_attention is None:
|
| 554 |
-
assert cross_attention_src is None
|
| 555 |
-
else:
|
| 556 |
-
assert cross_attention_src is not None
|
| 557 |
-
x = src
|
| 558 |
-
if self.norm_first:
|
| 559 |
-
x = x + self.layer_scale_1(
|
| 560 |
-
self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
|
| 561 |
-
if cross_attention_src is not None:
|
| 562 |
-
x = x + self.layer_scale_cross(
|
| 563 |
-
self._cross_attention_block(
|
| 564 |
-
self.norm_cross(x), cross_attention_src))
|
| 565 |
-
x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
|
| 566 |
-
else:
|
| 567 |
-
x = self.norm1(x + self.layer_scale_1(
|
| 568 |
-
self._sa_block(x, src_mask, src_key_padding_mask)))
|
| 569 |
-
if cross_attention_src is not None:
|
| 570 |
-
x = self.norm_cross(
|
| 571 |
-
x + self.layer_scale_cross(
|
| 572 |
-
self._cross_attention_block(src, cross_attention_src)))
|
| 573 |
-
x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
|
| 574 |
-
return x
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
class StreamingTransformer(StreamingModule):
|
| 578 |
-
"""Transformer with Streaming / Causal support.
|
| 579 |
-
|
| 580 |
-
Args:
|
| 581 |
-
d_model (int): Dimension of the data.
|
| 582 |
-
num_heads (int): Number of heads.
|
| 583 |
-
dim_feedforward (int): Intermediate dimension of FF module.
|
| 584 |
-
dropout (float): Dropout both for MHA and FF.
|
| 585 |
-
bias_ff (bool): Use bias for FF.
|
| 586 |
-
bias_attn (bool): Use bias for MHA.
|
| 587 |
-
causal (bool): Causal mask applied automatically.
|
| 588 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
| 589 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
| 590 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
| 591 |
-
attention_as_float32 (bool): Perform the attention as float32
|
| 592 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
| 593 |
-
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
| 594 |
-
layer_scale (float, optional): If not None, LayerScale will be used
|
| 595 |
-
with the given value as initial scale.
|
| 596 |
-
positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
|
| 597 |
-
max_period (float): Maximum period of the time embedding.
|
| 598 |
-
positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
|
| 599 |
-
xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
|
| 600 |
-
lr (float, optional): learning rate override through the `make_optim_group` API.
|
| 601 |
-
weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
|
| 602 |
-
layer_class: (subclass of `StreamingTransformerLayer): class to use
|
| 603 |
-
to initialize the layers, allowing further customization outside of AudioCraft.
|
| 604 |
-
checkpointing (str): Checkpointing strategy to reduce memory usage.
|
| 605 |
-
No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
|
| 606 |
-
if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
|
| 607 |
-
minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
|
| 608 |
-
a policy for opting-out some operations of the checkpointing like
|
| 609 |
-
linear layers and attention, providing a middle ground between speed and memory.
|
| 610 |
-
device (torch.device, optional): Device on which to initialize.
|
| 611 |
-
dtype (torch.dtype, optional): dtype to use.
|
| 612 |
-
**kwargs: See `nn.TransformerEncoderLayer`.
|
| 613 |
-
"""
|
| 614 |
-
def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
|
| 615 |
-
dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
|
| 616 |
-
causal: bool = False, past_context: tp.Optional[int] = None,
|
| 617 |
-
custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
|
| 618 |
-
cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
|
| 619 |
-
positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
|
| 620 |
-
xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
|
| 621 |
-
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
|
| 622 |
-
checkpointing: str = 'none', device=None, dtype=None, **kwargs):
|
| 623 |
-
super().__init__()
|
| 624 |
-
assert d_model % num_heads == 0
|
| 625 |
-
|
| 626 |
-
self.positional_embedding = positional_embedding
|
| 627 |
-
self.max_period = max_period
|
| 628 |
-
self.positional_scale = positional_scale
|
| 629 |
-
self.weight_decay = weight_decay
|
| 630 |
-
self.lr = lr
|
| 631 |
-
|
| 632 |
-
assert positional_embedding in ['sin', 'rope', 'sin_rope']
|
| 633 |
-
self.rope: tp.Optional[RotaryEmbedding] = None
|
| 634 |
-
if self.positional_embedding in ['rope', 'sin_rope']:
|
| 635 |
-
assert _is_custom(custom, memory_efficient)
|
| 636 |
-
self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
|
| 637 |
-
xpos=xpos, scale=positional_scale, device=device)
|
| 638 |
-
|
| 639 |
-
self.checkpointing = checkpointing
|
| 640 |
-
|
| 641 |
-
assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
|
| 642 |
-
if self.checkpointing.startswith('xformers'):
|
| 643 |
-
_verify_xformers_internal_compat()
|
| 644 |
-
|
| 645 |
-
self.layers = nn.ModuleList()
|
| 646 |
-
for idx in range(num_layers):
|
| 647 |
-
self.layers.append(
|
| 648 |
-
layer_class(
|
| 649 |
-
d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
|
| 650 |
-
dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
|
| 651 |
-
causal=causal, past_context=past_context, custom=custom,
|
| 652 |
-
memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
|
| 653 |
-
cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
|
| 654 |
-
device=device, dtype=dtype, **kwargs))
|
| 655 |
-
|
| 656 |
-
if self.checkpointing != 'none':
|
| 657 |
-
for layer in self.layers:
|
| 658 |
-
# see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
|
| 659 |
-
# backward hook inside of FSDP...
|
| 660 |
-
layer._magma_checkpointed = True # type: ignore
|
| 661 |
-
|
| 662 |
-
def _apply_layer(self, layer, *args, **kwargs):
|
| 663 |
-
method = self.checkpointing
|
| 664 |
-
if method == 'none':
|
| 665 |
-
return layer(*args, **kwargs)
|
| 666 |
-
elif method == 'torch':
|
| 667 |
-
return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
|
| 668 |
-
elif method.startswith('xformers'):
|
| 669 |
-
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
|
| 670 |
-
if method == 'xformers_default':
|
| 671 |
-
# those operations will be saved, and not recomputed.
|
| 672 |
-
# According to Francisco we can get smarter policies but this is a good start.
|
| 673 |
-
allow_list = [
|
| 674 |
-
"xformers.efficient_attention_forward_cutlass.default",
|
| 675 |
-
"xformers_flash.flash_fwd.default",
|
| 676 |
-
"aten.addmm.default",
|
| 677 |
-
"aten.mm.default",
|
| 678 |
-
]
|
| 679 |
-
elif method == 'xformers_mm':
|
| 680 |
-
# those operations will be saved, and not recomputed.
|
| 681 |
-
# According to Francisco we can get smarter policies but this is a good start.
|
| 682 |
-
allow_list = [
|
| 683 |
-
"aten.addmm.default",
|
| 684 |
-
"aten.mm.default",
|
| 685 |
-
]
|
| 686 |
-
else:
|
| 687 |
-
raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
|
| 688 |
-
policy_fn = _get_default_policy(allow_list)
|
| 689 |
-
return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
|
| 690 |
-
else:
|
| 691 |
-
raise ValueError(f"Checkpointing method {method} is unknown.")
|
| 692 |
-
|
| 693 |
-
def forward(self, x: torch.Tensor, *args, **kwargs):
|
| 694 |
-
B, T, C = x.shape
|
| 695 |
-
|
| 696 |
-
if 'offsets' in self._streaming_state:
|
| 697 |
-
offsets = self._streaming_state['offsets']
|
| 698 |
-
else:
|
| 699 |
-
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
|
| 700 |
-
|
| 701 |
-
if self.positional_embedding in ['sin', 'sin_rope']:
|
| 702 |
-
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
| 703 |
-
positions = positions + offsets.view(-1, 1, 1)
|
| 704 |
-
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
| 705 |
-
x = x + self.positional_scale * pos_emb
|
| 706 |
-
|
| 707 |
-
for layer in self.layers:
|
| 708 |
-
x = self._apply_layer(layer, x, *args, **kwargs)
|
| 709 |
-
|
| 710 |
-
if self._is_streaming:
|
| 711 |
-
self._streaming_state['offsets'] = offsets + T
|
| 712 |
-
|
| 713 |
-
return x
|
| 714 |
-
|
| 715 |
-
def make_optim_group(self):
|
| 716 |
-
group = {"params": list(self.parameters())}
|
| 717 |
-
if self.lr is not None:
|
| 718 |
-
group["lr"] = self.lr
|
| 719 |
-
if self.weight_decay is not None:
|
| 720 |
-
group["weight_decay"] = self.weight_decay
|
| 721 |
-
return group
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
# special attention related function
|
| 725 |
-
|
| 726 |
-
def _verify_xformers_memory_efficient_compat():
|
| 727 |
-
try:
|
| 728 |
-
from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
|
| 729 |
-
except ImportError:
|
| 730 |
-
raise ImportError(
|
| 731 |
-
"xformers is not installed. Please install it and try again.\n"
|
| 732 |
-
"To install on AWS and Azure, run \n"
|
| 733 |
-
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
|
| 734 |
-
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
|
| 735 |
-
"To install on FAIR Cluster, run \n"
|
| 736 |
-
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
|
| 737 |
-
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
def _verify_xformers_internal_compat():
|
| 741 |
-
try:
|
| 742 |
-
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
|
| 743 |
-
except ImportError:
|
| 744 |
-
raise ImportError(
|
| 745 |
-
"Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
|
| 746 |
-
"To install on AWS and Azure, run \n"
|
| 747 |
-
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
|
| 748 |
-
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
|
| 749 |
-
"To install on FAIR Cluster, run \n"
|
| 750 |
-
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
|
| 751 |
-
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
def _is_custom(custom: bool, memory_efficient: bool):
|
| 755 |
-
return custom or memory_efficient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/modules/unet_transformer.py
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import typing as tp
|
| 3 |
-
from .transformer import StreamingTransformer, create_sin_embedding
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class UnetTransformer(StreamingTransformer):
|
| 7 |
-
"""U-net Transformer for processing sequences with optional skip connections.
|
| 8 |
-
This transformer architecture incorporates U-net style skip connections
|
| 9 |
-
between layers, which can be optionally enabled. It inherits from a
|
| 10 |
-
StreamingTransformer.
|
| 11 |
-
|
| 12 |
-
Args:
|
| 13 |
-
d_model (int): Dimension of the model, typically the number of expected features in the input.
|
| 14 |
-
num_layers (int): Total number of layers in the transformer.
|
| 15 |
-
skip_connections (bool, optional): Flag to determine whether skip connections should be used.
|
| 16 |
-
Defaults to False.
|
| 17 |
-
layer_dropout_p (float, Optional): if given, defined bernoulli prob. to drop a skip connection (in training).
|
| 18 |
-
**kwargs: Additional keyword arguments inherited from `nn.StreamingTransformer`.
|
| 19 |
-
"""
|
| 20 |
-
def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False,
|
| 21 |
-
layer_dropout_p: tp.Optional[float] = None, **kwargs):
|
| 22 |
-
super().__init__(d_model=d_model,
|
| 23 |
-
num_layers=num_layers,
|
| 24 |
-
**kwargs)
|
| 25 |
-
self.skip_connect = skip_connections
|
| 26 |
-
if self.skip_connect:
|
| 27 |
-
self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model)
|
| 28 |
-
for _ in range(num_layers // 2)])
|
| 29 |
-
self.num_layers = num_layers
|
| 30 |
-
self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0
|
| 31 |
-
|
| 32 |
-
def forward(self, x: torch.Tensor, *args, **kwargs):
|
| 33 |
-
B, T, C = x.shape
|
| 34 |
-
|
| 35 |
-
if 'offsets' in self._streaming_state:
|
| 36 |
-
offsets = self._streaming_state['offsets']
|
| 37 |
-
else:
|
| 38 |
-
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
|
| 39 |
-
|
| 40 |
-
if self.positional_embedding in ['sin', 'sin_rope']:
|
| 41 |
-
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
| 42 |
-
positions = positions + offsets.view(-1, 1, 1)
|
| 43 |
-
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
| 44 |
-
x = x + self.positional_scale * pos_emb
|
| 45 |
-
|
| 46 |
-
skip_connections: tp.List[torch.Tensor] = []
|
| 47 |
-
|
| 48 |
-
for i, layer in enumerate(self.layers):
|
| 49 |
-
if self.skip_connect and i >= self.num_layers // 2:
|
| 50 |
-
|
| 51 |
-
# in the second half of the layers, add residual connection
|
| 52 |
-
# and linearly project the concatenated features back to d_model
|
| 53 |
-
x = torch.cat([x, skip_connections.pop()], dim=-1)
|
| 54 |
-
x = self.skip_projections[i % len(self.skip_projections)](x)
|
| 55 |
-
|
| 56 |
-
x = self._apply_layer(layer, x, *args, **kwargs)
|
| 57 |
-
|
| 58 |
-
if self.skip_connect and i < self.num_layers // 2:
|
| 59 |
-
if self.training and torch.rand(1,) < self.layer_drop_p: # drop skip
|
| 60 |
-
skip_connections.append(torch.zeros_like(x))
|
| 61 |
-
else:
|
| 62 |
-
skip_connections.append(x)
|
| 63 |
-
|
| 64 |
-
if self._is_streaming:
|
| 65 |
-
self._streaming_state['offsets'] = offsets + T
|
| 66 |
-
|
| 67 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/py.typed
DELETED
|
File without changes
|
audiocraft/quantization/__init__.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
# flake8: noqa
|
| 8 |
-
from .vq import ResidualVectorQuantizer
|
| 9 |
-
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/quantization/base.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Base class for all quantizers.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from dataclasses import dataclass, field
|
| 12 |
-
import typing as tp
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
from torch import nn
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
@dataclass
|
| 19 |
-
class QuantizedResult:
|
| 20 |
-
x: torch.Tensor
|
| 21 |
-
codes: torch.Tensor
|
| 22 |
-
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
| 23 |
-
penalty: tp.Optional[torch.Tensor] = None
|
| 24 |
-
metrics: dict = field(default_factory=dict)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class BaseQuantizer(nn.Module):
|
| 28 |
-
"""Base class for quantizers.
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
|
| 32 |
-
"""
|
| 33 |
-
Given input tensor x, returns first the quantized (or approximately quantized)
|
| 34 |
-
representation along with quantized codes, bandwidth, and any penalty term for the loss.
|
| 35 |
-
Finally, this returns a dict of metrics to update logging etc.
|
| 36 |
-
Frame rate must be passed so that the bandwidth is properly computed.
|
| 37 |
-
"""
|
| 38 |
-
raise NotImplementedError()
|
| 39 |
-
|
| 40 |
-
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 41 |
-
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
| 42 |
-
"""
|
| 43 |
-
raise NotImplementedError()
|
| 44 |
-
|
| 45 |
-
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 46 |
-
"""Decode the given codes to the quantized representation.
|
| 47 |
-
"""
|
| 48 |
-
raise NotImplementedError()
|
| 49 |
-
|
| 50 |
-
@property
|
| 51 |
-
def total_codebooks(self):
|
| 52 |
-
"""Total number of codebooks.
|
| 53 |
-
"""
|
| 54 |
-
raise NotImplementedError()
|
| 55 |
-
|
| 56 |
-
@property
|
| 57 |
-
def num_codebooks(self):
|
| 58 |
-
"""Number of active codebooks.
|
| 59 |
-
"""
|
| 60 |
-
raise NotImplementedError()
|
| 61 |
-
|
| 62 |
-
def set_num_codebooks(self, n: int):
|
| 63 |
-
"""Set the number of active codebooks.
|
| 64 |
-
"""
|
| 65 |
-
raise NotImplementedError()
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class DummyQuantizer(BaseQuantizer):
|
| 69 |
-
"""Fake quantizer that actually does not perform any quantization.
|
| 70 |
-
"""
|
| 71 |
-
def __init__(self):
|
| 72 |
-
super().__init__()
|
| 73 |
-
|
| 74 |
-
def forward(self, x: torch.Tensor, frame_rate: int):
|
| 75 |
-
q = x.unsqueeze(1)
|
| 76 |
-
return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
|
| 77 |
-
|
| 78 |
-
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 79 |
-
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
| 80 |
-
In the case of the DummyQuantizer, the codes are actually identical
|
| 81 |
-
to the input and resulting quantized representation as no quantization is done.
|
| 82 |
-
"""
|
| 83 |
-
return x.unsqueeze(1)
|
| 84 |
-
|
| 85 |
-
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 86 |
-
"""Decode the given codes to the quantized representation.
|
| 87 |
-
In the case of the DummyQuantizer, the codes are actually identical
|
| 88 |
-
to the input and resulting quantized representation as no quantization is done.
|
| 89 |
-
"""
|
| 90 |
-
return codes.squeeze(1)
|
| 91 |
-
|
| 92 |
-
@property
|
| 93 |
-
def total_codebooks(self):
|
| 94 |
-
"""Total number of codebooks.
|
| 95 |
-
"""
|
| 96 |
-
return 1
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def num_codebooks(self):
|
| 100 |
-
"""Total number of codebooks.
|
| 101 |
-
"""
|
| 102 |
-
return self.total_codebooks
|
| 103 |
-
|
| 104 |
-
def set_num_codebooks(self, n: int):
|
| 105 |
-
"""Set the number of active codebooks.
|
| 106 |
-
"""
|
| 107 |
-
raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/quantization/core_vq.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import typing as tp
|
| 8 |
-
|
| 9 |
-
from einops import rearrange, repeat
|
| 10 |
-
import flashy
|
| 11 |
-
import torch
|
| 12 |
-
from torch import nn, einsum
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def exists(val: tp.Optional[tp.Any]) -> bool:
|
| 17 |
-
return val is not None
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
| 21 |
-
return val if exists(val) else d
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def l2norm(t):
|
| 25 |
-
return F.normalize(t, p=2, dim=-1)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def ema_inplace(moving_avg, new, decay: float):
|
| 29 |
-
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
| 33 |
-
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def uniform_init(*shape: int):
|
| 37 |
-
t = torch.empty(shape)
|
| 38 |
-
nn.init.kaiming_uniform_(t)
|
| 39 |
-
return t
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def sample_vectors(samples, num: int):
|
| 43 |
-
num_samples, device = samples.shape[0], samples.device
|
| 44 |
-
|
| 45 |
-
if num_samples >= num:
|
| 46 |
-
indices = torch.randperm(num_samples, device=device)[:num]
|
| 47 |
-
else:
|
| 48 |
-
indices = torch.randint(0, num_samples, (num,), device=device)
|
| 49 |
-
|
| 50 |
-
return samples[indices]
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
| 54 |
-
dim, dtype = samples.shape[-1], samples.dtype
|
| 55 |
-
|
| 56 |
-
means = sample_vectors(samples, num_clusters)
|
| 57 |
-
|
| 58 |
-
for _ in range(num_iters):
|
| 59 |
-
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
| 60 |
-
means, "c d -> () c d"
|
| 61 |
-
)
|
| 62 |
-
dists = -(diffs ** 2).sum(dim=-1)
|
| 63 |
-
|
| 64 |
-
buckets = dists.max(dim=-1).indices
|
| 65 |
-
bins = torch.bincount(buckets, minlength=num_clusters)
|
| 66 |
-
zero_mask = bins == 0
|
| 67 |
-
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
| 68 |
-
|
| 69 |
-
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
| 70 |
-
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
| 71 |
-
new_means = new_means / bins_min_clamped[..., None]
|
| 72 |
-
|
| 73 |
-
means = torch.where(zero_mask[..., None], means, new_means)
|
| 74 |
-
|
| 75 |
-
return means, bins
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def orthogonal_loss_fn(t):
|
| 79 |
-
# eq (2) from https://arxiv.org/abs/2112.00384
|
| 80 |
-
n = t.shape[0]
|
| 81 |
-
normed_codes = l2norm(t)
|
| 82 |
-
identity = torch.eye(n, device=t.device)
|
| 83 |
-
cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
|
| 84 |
-
return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class EuclideanCodebook(nn.Module):
|
| 88 |
-
"""Codebook with Euclidean distance.
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
dim (int): Dimension.
|
| 92 |
-
codebook_size (int): Codebook size.
|
| 93 |
-
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
| 94 |
-
If set to true, run the k-means algorithm on the first training batch and use
|
| 95 |
-
the learned centroids as initialization.
|
| 96 |
-
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
| 97 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
| 98 |
-
epsilon (float): Epsilon value for numerical stability.
|
| 99 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 100 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
| 101 |
-
randomly selected vector from the current batch.
|
| 102 |
-
"""
|
| 103 |
-
def __init__(
|
| 104 |
-
self,
|
| 105 |
-
dim: int,
|
| 106 |
-
codebook_size: int,
|
| 107 |
-
kmeans_init: int = False,
|
| 108 |
-
kmeans_iters: int = 10,
|
| 109 |
-
decay: float = 0.8,
|
| 110 |
-
epsilon: float = 1e-5,
|
| 111 |
-
threshold_ema_dead_code: int = 2,
|
| 112 |
-
):
|
| 113 |
-
super().__init__()
|
| 114 |
-
self.decay = decay
|
| 115 |
-
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
| 116 |
-
embed = init_fn(codebook_size, dim)
|
| 117 |
-
|
| 118 |
-
self.codebook_size = codebook_size
|
| 119 |
-
|
| 120 |
-
self.kmeans_iters = kmeans_iters
|
| 121 |
-
self.epsilon = epsilon
|
| 122 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 123 |
-
|
| 124 |
-
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
| 125 |
-
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
| 126 |
-
self.register_buffer("embed", embed)
|
| 127 |
-
self.register_buffer("embed_avg", embed.clone())
|
| 128 |
-
|
| 129 |
-
@torch.jit.ignore
|
| 130 |
-
def init_embed_(self, data):
|
| 131 |
-
if self.inited:
|
| 132 |
-
return
|
| 133 |
-
|
| 134 |
-
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
| 135 |
-
self.embed.data.copy_(embed)
|
| 136 |
-
self.embed_avg.data.copy_(embed.clone())
|
| 137 |
-
self.cluster_size.data.copy_(cluster_size)
|
| 138 |
-
self.inited.data.copy_(torch.Tensor([True]))
|
| 139 |
-
# Make sure all buffers across workers are in sync after initialization
|
| 140 |
-
flashy.distrib.broadcast_tensors(self.buffers())
|
| 141 |
-
|
| 142 |
-
def replace_(self, samples, mask):
|
| 143 |
-
modified_codebook = torch.where(
|
| 144 |
-
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
| 145 |
-
)
|
| 146 |
-
self.embed.data.copy_(modified_codebook)
|
| 147 |
-
|
| 148 |
-
def expire_codes_(self, batch_samples):
|
| 149 |
-
if self.threshold_ema_dead_code == 0:
|
| 150 |
-
return
|
| 151 |
-
|
| 152 |
-
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
| 153 |
-
if not torch.any(expired_codes):
|
| 154 |
-
return
|
| 155 |
-
|
| 156 |
-
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
| 157 |
-
self.replace_(batch_samples, mask=expired_codes)
|
| 158 |
-
flashy.distrib.broadcast_tensors(self.buffers())
|
| 159 |
-
|
| 160 |
-
def preprocess(self, x):
|
| 161 |
-
x = rearrange(x, "... d -> (...) d")
|
| 162 |
-
return x
|
| 163 |
-
|
| 164 |
-
def quantize(self, x):
|
| 165 |
-
embed = self.embed.t()
|
| 166 |
-
dist = -(
|
| 167 |
-
x.pow(2).sum(1, keepdim=True)
|
| 168 |
-
- 2 * x @ embed
|
| 169 |
-
+ embed.pow(2).sum(0, keepdim=True)
|
| 170 |
-
)
|
| 171 |
-
embed_ind = dist.max(dim=-1).indices
|
| 172 |
-
return embed_ind
|
| 173 |
-
|
| 174 |
-
def postprocess_emb(self, embed_ind, shape):
|
| 175 |
-
return embed_ind.view(*shape[:-1])
|
| 176 |
-
|
| 177 |
-
def dequantize(self, embed_ind):
|
| 178 |
-
quantize = F.embedding(embed_ind, self.embed)
|
| 179 |
-
return quantize
|
| 180 |
-
|
| 181 |
-
def encode(self, x):
|
| 182 |
-
shape = x.shape
|
| 183 |
-
# pre-process
|
| 184 |
-
x = self.preprocess(x)
|
| 185 |
-
# quantize
|
| 186 |
-
embed_ind = self.quantize(x)
|
| 187 |
-
# post-process
|
| 188 |
-
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 189 |
-
return embed_ind
|
| 190 |
-
|
| 191 |
-
def decode(self, embed_ind):
|
| 192 |
-
quantize = self.dequantize(embed_ind)
|
| 193 |
-
return quantize
|
| 194 |
-
|
| 195 |
-
def forward(self, x):
|
| 196 |
-
shape, dtype = x.shape, x.dtype
|
| 197 |
-
x = self.preprocess(x)
|
| 198 |
-
self.init_embed_(x)
|
| 199 |
-
|
| 200 |
-
embed_ind = self.quantize(x)
|
| 201 |
-
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
| 202 |
-
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 203 |
-
quantize = self.dequantize(embed_ind)
|
| 204 |
-
|
| 205 |
-
if self.training:
|
| 206 |
-
# We do the expiry of code at that point as buffers are in sync
|
| 207 |
-
# and all the workers will take the same decision.
|
| 208 |
-
self.expire_codes_(x)
|
| 209 |
-
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
| 210 |
-
embed_sum = x.t() @ embed_onehot
|
| 211 |
-
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
| 212 |
-
cluster_size = (
|
| 213 |
-
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
| 214 |
-
* self.cluster_size.sum()
|
| 215 |
-
)
|
| 216 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
| 217 |
-
self.embed.data.copy_(embed_normalized)
|
| 218 |
-
|
| 219 |
-
return quantize, embed_ind
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
class VectorQuantization(nn.Module):
|
| 223 |
-
"""Vector quantization implementation.
|
| 224 |
-
Currently supports only euclidean distance.
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
dim (int): Dimension
|
| 228 |
-
codebook_size (int): Codebook size
|
| 229 |
-
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
| 230 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
| 231 |
-
epsilon (float): Epsilon value for numerical stability.
|
| 232 |
-
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 233 |
-
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 234 |
-
threshold_ema_dead_code (int):
|
| 235 |
-
channels_last (bool): Channels are the last dimension in the input tensors.
|
| 236 |
-
commitment_weight (float): Weight for commitment loss.
|
| 237 |
-
orthogonal_reg_weight (float): Orthogonal regularization weights.
|
| 238 |
-
orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
|
| 239 |
-
orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
|
| 240 |
-
for orthogonal regularization.
|
| 241 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 242 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
| 243 |
-
randomly selected vector from the current batch.
|
| 244 |
-
"""
|
| 245 |
-
def __init__(
|
| 246 |
-
self,
|
| 247 |
-
dim: int,
|
| 248 |
-
codebook_size: int,
|
| 249 |
-
codebook_dim: tp.Optional[int] = None,
|
| 250 |
-
decay: float = 0.8,
|
| 251 |
-
epsilon: float = 1e-5,
|
| 252 |
-
kmeans_init: bool = False,
|
| 253 |
-
kmeans_iters: int = 10,
|
| 254 |
-
threshold_ema_dead_code: int = 2,
|
| 255 |
-
channels_last: bool = False,
|
| 256 |
-
commitment_weight: float = 1.,
|
| 257 |
-
orthogonal_reg_weight: float = 0.0,
|
| 258 |
-
orthogonal_reg_active_codes_only: bool = False,
|
| 259 |
-
orthogonal_reg_max_codes: tp.Optional[int] = None,
|
| 260 |
-
):
|
| 261 |
-
super().__init__()
|
| 262 |
-
_codebook_dim: int = default(codebook_dim, dim)
|
| 263 |
-
|
| 264 |
-
requires_projection = _codebook_dim != dim
|
| 265 |
-
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
| 266 |
-
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
| 267 |
-
|
| 268 |
-
self.epsilon = epsilon
|
| 269 |
-
self.commitment_weight = commitment_weight
|
| 270 |
-
|
| 271 |
-
self.orthogonal_reg_weight = orthogonal_reg_weight
|
| 272 |
-
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
| 273 |
-
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
| 274 |
-
|
| 275 |
-
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
| 276 |
-
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
| 277 |
-
decay=decay, epsilon=epsilon,
|
| 278 |
-
threshold_ema_dead_code=threshold_ema_dead_code)
|
| 279 |
-
self.codebook_size = codebook_size
|
| 280 |
-
|
| 281 |
-
self.channels_last = channels_last
|
| 282 |
-
|
| 283 |
-
@property
|
| 284 |
-
def codebook(self):
|
| 285 |
-
return self._codebook.embed
|
| 286 |
-
|
| 287 |
-
@property
|
| 288 |
-
def inited(self):
|
| 289 |
-
return self._codebook.inited
|
| 290 |
-
|
| 291 |
-
def _preprocess(self, x):
|
| 292 |
-
if not self.channels_last:
|
| 293 |
-
x = rearrange(x, "b d n -> b n d")
|
| 294 |
-
return x
|
| 295 |
-
|
| 296 |
-
def _postprocess(self, quantize):
|
| 297 |
-
if not self.channels_last:
|
| 298 |
-
quantize = rearrange(quantize, "b n d -> b d n")
|
| 299 |
-
return quantize
|
| 300 |
-
|
| 301 |
-
def encode(self, x):
|
| 302 |
-
x = self._preprocess(x)
|
| 303 |
-
x = self.project_in(x)
|
| 304 |
-
embed_in = self._codebook.encode(x)
|
| 305 |
-
return embed_in
|
| 306 |
-
|
| 307 |
-
def decode(self, embed_ind):
|
| 308 |
-
quantize = self._codebook.decode(embed_ind)
|
| 309 |
-
quantize = self.project_out(quantize)
|
| 310 |
-
quantize = self._postprocess(quantize)
|
| 311 |
-
return quantize
|
| 312 |
-
|
| 313 |
-
def forward(self, x):
|
| 314 |
-
device = x.device
|
| 315 |
-
x = self._preprocess(x)
|
| 316 |
-
|
| 317 |
-
x = self.project_in(x)
|
| 318 |
-
quantize, embed_ind = self._codebook(x)
|
| 319 |
-
|
| 320 |
-
if self.training:
|
| 321 |
-
quantize = x + (quantize - x).detach()
|
| 322 |
-
|
| 323 |
-
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
| 324 |
-
|
| 325 |
-
if self.training:
|
| 326 |
-
if self.commitment_weight > 0:
|
| 327 |
-
commit_loss = F.mse_loss(quantize.detach(), x)
|
| 328 |
-
loss = loss + commit_loss * self.commitment_weight
|
| 329 |
-
|
| 330 |
-
if self.orthogonal_reg_weight > 0:
|
| 331 |
-
codebook = self.codebook
|
| 332 |
-
|
| 333 |
-
if self.orthogonal_reg_active_codes_only:
|
| 334 |
-
# only calculate orthogonal loss for the activated codes for this batch
|
| 335 |
-
unique_code_ids = torch.unique(embed_ind)
|
| 336 |
-
codebook = codebook[unique_code_ids]
|
| 337 |
-
|
| 338 |
-
num_codes = codebook.shape[0]
|
| 339 |
-
if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
|
| 340 |
-
rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
|
| 341 |
-
codebook = codebook[rand_ids]
|
| 342 |
-
|
| 343 |
-
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
|
| 344 |
-
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
|
| 345 |
-
|
| 346 |
-
quantize = self.project_out(quantize)
|
| 347 |
-
quantize = self._postprocess(quantize)
|
| 348 |
-
|
| 349 |
-
return quantize, embed_ind, loss
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
class ResidualVectorQuantization(nn.Module):
|
| 353 |
-
"""Residual vector quantization implementation.
|
| 354 |
-
|
| 355 |
-
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 356 |
-
"""
|
| 357 |
-
def __init__(self, *, num_quantizers, **kwargs):
|
| 358 |
-
super().__init__()
|
| 359 |
-
self.layers = nn.ModuleList(
|
| 360 |
-
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
| 361 |
-
)
|
| 362 |
-
|
| 363 |
-
def forward(self, x, n_q: tp.Optional[int] = None):
|
| 364 |
-
quantized_out = 0.0
|
| 365 |
-
residual = x
|
| 366 |
-
|
| 367 |
-
all_losses = []
|
| 368 |
-
all_indices = []
|
| 369 |
-
|
| 370 |
-
n_q = n_q or len(self.layers)
|
| 371 |
-
|
| 372 |
-
for i, layer in enumerate(self.layers[:n_q]):
|
| 373 |
-
quantized, indices, loss = layer(residual)
|
| 374 |
-
quantized = quantized.detach()
|
| 375 |
-
residual = residual - quantized
|
| 376 |
-
quantized_out = quantized_out + quantized
|
| 377 |
-
all_indices.append(indices)
|
| 378 |
-
all_losses.append(loss)
|
| 379 |
-
|
| 380 |
-
if self.training:
|
| 381 |
-
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
|
| 382 |
-
quantized_out = x + (quantized_out - x).detach()
|
| 383 |
-
|
| 384 |
-
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
| 385 |
-
return quantized_out, out_indices, out_losses
|
| 386 |
-
|
| 387 |
-
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
| 388 |
-
residual = x
|
| 389 |
-
all_indices = []
|
| 390 |
-
n_q = n_q or len(self.layers)
|
| 391 |
-
for layer in self.layers[:n_q]:
|
| 392 |
-
indices = layer.encode(residual)
|
| 393 |
-
quantized = layer.decode(indices)
|
| 394 |
-
residual = residual - quantized
|
| 395 |
-
all_indices.append(indices)
|
| 396 |
-
out_indices = torch.stack(all_indices)
|
| 397 |
-
return out_indices
|
| 398 |
-
|
| 399 |
-
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
| 400 |
-
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
| 401 |
-
for i, indices in enumerate(q_indices):
|
| 402 |
-
layer = self.layers[i]
|
| 403 |
-
quantized = layer.decode(indices)
|
| 404 |
-
quantized_out = quantized_out + quantized
|
| 405 |
-
return quantized_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/quantization/vq.py
DELETED
|
@@ -1,116 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import math
|
| 8 |
-
import typing as tp
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
from .base import BaseQuantizer, QuantizedResult
|
| 13 |
-
from .core_vq import ResidualVectorQuantization
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class ResidualVectorQuantizer(BaseQuantizer):
|
| 17 |
-
"""Residual Vector Quantizer.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
dimension (int): Dimension of the codebooks.
|
| 21 |
-
n_q (int): Number of residual vector quantizers used.
|
| 22 |
-
q_dropout (bool): Random quantizer drop out at train time.
|
| 23 |
-
bins (int): Codebook size.
|
| 24 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
| 25 |
-
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
| 26 |
-
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
| 27 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
| 28 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
| 29 |
-
randomly selected vector from the current batch.
|
| 30 |
-
orthogonal_reg_weight (float): Orthogonal regularization weights.
|
| 31 |
-
orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
|
| 32 |
-
orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
|
| 33 |
-
for orthogonal regulariation.
|
| 34 |
-
"""
|
| 35 |
-
def __init__(
|
| 36 |
-
self,
|
| 37 |
-
dimension: int = 256,
|
| 38 |
-
n_q: int = 8,
|
| 39 |
-
q_dropout: bool = False,
|
| 40 |
-
bins: int = 1024,
|
| 41 |
-
decay: float = 0.99,
|
| 42 |
-
kmeans_init: bool = True,
|
| 43 |
-
kmeans_iters: int = 10,
|
| 44 |
-
threshold_ema_dead_code: int = 2,
|
| 45 |
-
orthogonal_reg_weight: float = 0.0,
|
| 46 |
-
orthogonal_reg_active_codes_only: bool = False,
|
| 47 |
-
orthogonal_reg_max_codes: tp.Optional[int] = None,
|
| 48 |
-
):
|
| 49 |
-
super().__init__()
|
| 50 |
-
self.max_n_q = n_q
|
| 51 |
-
self.n_q = n_q
|
| 52 |
-
self.q_dropout = q_dropout
|
| 53 |
-
self.dimension = dimension
|
| 54 |
-
self.bins = bins
|
| 55 |
-
self.decay = decay
|
| 56 |
-
self.kmeans_init = kmeans_init
|
| 57 |
-
self.kmeans_iters = kmeans_iters
|
| 58 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 59 |
-
self.orthogonal_reg_weight = orthogonal_reg_weight
|
| 60 |
-
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
| 61 |
-
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
| 62 |
-
self.vq = ResidualVectorQuantization(
|
| 63 |
-
dim=self.dimension,
|
| 64 |
-
codebook_size=self.bins,
|
| 65 |
-
num_quantizers=self.n_q,
|
| 66 |
-
decay=self.decay,
|
| 67 |
-
kmeans_init=self.kmeans_init,
|
| 68 |
-
kmeans_iters=self.kmeans_iters,
|
| 69 |
-
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
| 70 |
-
orthogonal_reg_weight=self.orthogonal_reg_weight,
|
| 71 |
-
orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
|
| 72 |
-
orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
|
| 73 |
-
channels_last=False
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
def forward(self, x: torch.Tensor, frame_rate: int):
|
| 77 |
-
n_q = self.n_q
|
| 78 |
-
if self.training and self.q_dropout:
|
| 79 |
-
n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
|
| 80 |
-
bw_per_q = math.log2(self.bins) * frame_rate / 1000
|
| 81 |
-
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
|
| 82 |
-
codes = codes.transpose(0, 1)
|
| 83 |
-
# codes is [B, K, T], with T frames, K nb of codebooks.
|
| 84 |
-
bw = torch.tensor(n_q * bw_per_q).to(x)
|
| 85 |
-
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
|
| 86 |
-
|
| 87 |
-
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 88 |
-
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
|
| 89 |
-
The RVQ encode method sets the appropriate number of quantizer to use
|
| 90 |
-
and returns indices for each quantizer.
|
| 91 |
-
"""
|
| 92 |
-
n_q = self.n_q
|
| 93 |
-
codes = self.vq.encode(x, n_q=n_q)
|
| 94 |
-
codes = codes.transpose(0, 1)
|
| 95 |
-
# codes is [B, K, T], with T frames, K nb of codebooks.
|
| 96 |
-
return codes
|
| 97 |
-
|
| 98 |
-
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 99 |
-
"""Decode the given codes to the quantized representation.
|
| 100 |
-
"""
|
| 101 |
-
# codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
|
| 102 |
-
codes = codes.transpose(0, 1)
|
| 103 |
-
quantized = self.vq.decode(codes)
|
| 104 |
-
return quantized
|
| 105 |
-
|
| 106 |
-
@property
|
| 107 |
-
def total_codebooks(self):
|
| 108 |
-
return self.max_n_q
|
| 109 |
-
|
| 110 |
-
@property
|
| 111 |
-
def num_codebooks(self):
|
| 112 |
-
return self.n_q
|
| 113 |
-
|
| 114 |
-
def set_num_codebooks(self, n: int):
|
| 115 |
-
assert n > 0 and n <= self.max_n_q
|
| 116 |
-
self.n_q = n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/autocast.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class TorchAutocast:
|
| 11 |
-
"""TorchAutocast utility class.
|
| 12 |
-
Allows you to enable and disable autocast. This is specially useful
|
| 13 |
-
when dealing with different architectures and clusters with different
|
| 14 |
-
levels of support.
|
| 15 |
-
|
| 16 |
-
Args:
|
| 17 |
-
enabled (bool): Whether to enable torch.autocast or not.
|
| 18 |
-
args: Additional args for torch.autocast.
|
| 19 |
-
kwargs: Additional kwargs for torch.autocast
|
| 20 |
-
"""
|
| 21 |
-
def __init__(self, enabled: bool, *args, **kwargs):
|
| 22 |
-
self.autocast = torch.autocast(*args, **kwargs) if enabled else None
|
| 23 |
-
|
| 24 |
-
def __enter__(self):
|
| 25 |
-
if self.autocast is None:
|
| 26 |
-
return
|
| 27 |
-
try:
|
| 28 |
-
self.autocast.__enter__()
|
| 29 |
-
except RuntimeError:
|
| 30 |
-
device = self.autocast.device
|
| 31 |
-
dtype = self.autocast.fast_dtype
|
| 32 |
-
raise RuntimeError(
|
| 33 |
-
f"There was an error autocasting with dtype={dtype} device={device}\n"
|
| 34 |
-
"If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def __exit__(self, *args, **kwargs):
|
| 38 |
-
if self.autocast is None:
|
| 39 |
-
return
|
| 40 |
-
self.autocast.__exit__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/cache.py
DELETED
|
@@ -1,324 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
-
from collections import deque
|
| 9 |
-
from functools import partial
|
| 10 |
-
from hashlib import sha1
|
| 11 |
-
import logging
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
import sys
|
| 14 |
-
import typing as tp
|
| 15 |
-
import zipfile
|
| 16 |
-
|
| 17 |
-
import flashy
|
| 18 |
-
import torch
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
|
| 25 |
-
"""Utility function for the EmbeddingCache, returning the full embedding without any chunking.
|
| 26 |
-
This method can be used in case there is no need in extracting a chunk of the full embedding
|
| 27 |
-
read from the cache.
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
full_embed (torch.Tensor): The full embedding.
|
| 31 |
-
x (any): Batch object from which the full embedding is derived.
|
| 32 |
-
idx (torch.Tensor): Index of object to consider in the batch object.
|
| 33 |
-
Returns:
|
| 34 |
-
full_embed (torch.Tensor): The full embedding
|
| 35 |
-
"""
|
| 36 |
-
return full_embed.to(device)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class EmbeddingCache:
|
| 40 |
-
"""Cache around embeddings computation for faster execution.
|
| 41 |
-
The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
|
| 42 |
-
to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
|
| 43 |
-
using a user-provided function. When the cache is warm (all embeddings are pre-computed),
|
| 44 |
-
the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
|
| 45 |
-
Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
|
| 46 |
-
and synchronization points in the forward calls.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
|
| 50 |
-
device (str or torch.device): Device on which the embedding is returned.
|
| 51 |
-
compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
|
| 52 |
-
the embedding from a given object and path. This user provided function can compute the
|
| 53 |
-
embedding from the provided object or using the provided path as entry point. The last parameter
|
| 54 |
-
specify the index corresponding to the current embedding in the object that can represent batch metadata.
|
| 55 |
-
extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
|
| 56 |
-
the desired embedding chunk from the full embedding loaded from the cache. The last parameter
|
| 57 |
-
specify the index corresponding to the current embedding in the object that can represent batch metadata.
|
| 58 |
-
If not specified, will return the full embedding unmodified.
|
| 59 |
-
"""
|
| 60 |
-
def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
|
| 61 |
-
compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
|
| 62 |
-
extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
|
| 63 |
-
self.cache_path = Path(cache_path)
|
| 64 |
-
self.device = device
|
| 65 |
-
self._compute_embed_fn = compute_embed_fn
|
| 66 |
-
self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
|
| 67 |
-
if extract_embed_fn is not None:
|
| 68 |
-
self._extract_embed_fn = extract_embed_fn
|
| 69 |
-
else:
|
| 70 |
-
self._extract_embed_fn = partial(get_full_embed, device=device)
|
| 71 |
-
if self.cache_path is not None:
|
| 72 |
-
self.cache_path.mkdir(exist_ok=True, parents=True)
|
| 73 |
-
logger.info(f"Cache instantiated at: {self.cache_path}")
|
| 74 |
-
self.pool = ThreadPoolExecutor(8)
|
| 75 |
-
self.pool.__enter__()
|
| 76 |
-
self._current_batch_cache: dict = {}
|
| 77 |
-
self._memory_cache: dict = {}
|
| 78 |
-
|
| 79 |
-
def _get_cache_path(self, path: tp.Union[Path, str]):
|
| 80 |
-
"""Get cache path for the given file path."""
|
| 81 |
-
sig = sha1(str(path).encode()).hexdigest()
|
| 82 |
-
return self.cache_path / sig
|
| 83 |
-
|
| 84 |
-
@staticmethod
|
| 85 |
-
def _get_full_embed_from_cache(cache: Path):
|
| 86 |
-
"""Loads full pre-computed embedding from the cache."""
|
| 87 |
-
try:
|
| 88 |
-
embed = torch.load(cache, 'cpu')
|
| 89 |
-
except Exception as exc:
|
| 90 |
-
logger.error("Error loading %s: %r", cache, exc)
|
| 91 |
-
embed = None
|
| 92 |
-
return embed
|
| 93 |
-
|
| 94 |
-
def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
|
| 95 |
-
"""Get embedding from cache, computing and storing it to cache if not already cached.
|
| 96 |
-
The EmbeddingCache first tries to load the embedding from the in-memory cache
|
| 97 |
-
containing the pre-computed chunks populated through `populate_embed_cache`.
|
| 98 |
-
If not found, the full embedding is computed and stored on disk to be later accessed
|
| 99 |
-
to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
|
| 100 |
-
|
| 101 |
-
Args:
|
| 102 |
-
paths (list[Path or str]): List of paths from where the embeddings can be loaded.
|
| 103 |
-
x (any): Object from which the embedding is extracted.
|
| 104 |
-
"""
|
| 105 |
-
embeds = []
|
| 106 |
-
for idx, path in enumerate(paths):
|
| 107 |
-
cache = self._get_cache_path(path)
|
| 108 |
-
if cache in self._current_batch_cache:
|
| 109 |
-
embed = self._current_batch_cache[cache]
|
| 110 |
-
else:
|
| 111 |
-
full_embed = self._compute_embed_fn(path, x, idx)
|
| 112 |
-
try:
|
| 113 |
-
with flashy.utils.write_and_rename(cache, pid=True) as f:
|
| 114 |
-
torch.save(full_embed.cpu(), f)
|
| 115 |
-
except Exception as exc:
|
| 116 |
-
logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
|
| 117 |
-
else:
|
| 118 |
-
logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
|
| 119 |
-
embed = self._extract_embed_fn(full_embed, x, idx)
|
| 120 |
-
embeds.append(embed)
|
| 121 |
-
embed = torch.stack(embeds, dim=0)
|
| 122 |
-
return embed
|
| 123 |
-
|
| 124 |
-
def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
|
| 125 |
-
"""Populate in-memory caches for embeddings reading from the embeddings stored on disk.
|
| 126 |
-
The in-memory caches consist in a cache for the full embedding and another cache for the
|
| 127 |
-
final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
|
| 128 |
-
and reduce the IO footprint and synchronization points during forward passes.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
paths (list[Path]): List of paths from where the embeddings can be loaded.
|
| 132 |
-
x (any): Object from which the embedding is extracted.
|
| 133 |
-
"""
|
| 134 |
-
self._current_batch_cache.clear()
|
| 135 |
-
if self.cache_path is not None:
|
| 136 |
-
futures: list = []
|
| 137 |
-
for path in paths:
|
| 138 |
-
assert path is not None, "Path is required for computation from cache"
|
| 139 |
-
cache = self._get_cache_path(path)
|
| 140 |
-
if cache in self._memory_cache or not cache.exists():
|
| 141 |
-
futures.append(None)
|
| 142 |
-
else:
|
| 143 |
-
futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
|
| 144 |
-
for idx, (path, future) in enumerate(zip(paths, futures)):
|
| 145 |
-
assert path is not None
|
| 146 |
-
cache = self._get_cache_path(path)
|
| 147 |
-
full_embed = None
|
| 148 |
-
if future is None:
|
| 149 |
-
if cache in self._memory_cache:
|
| 150 |
-
full_embed = self._memory_cache[cache]
|
| 151 |
-
else:
|
| 152 |
-
full_embed = future.result()
|
| 153 |
-
if full_embed is not None:
|
| 154 |
-
self._memory_cache[cache] = full_embed
|
| 155 |
-
full_embed = full_embed.to(self.device)
|
| 156 |
-
if full_embed is not None:
|
| 157 |
-
embed = self._extract_embed_fn(full_embed, x, idx)
|
| 158 |
-
self._current_batch_cache[cache] = embed
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
class CachedBatchWriter:
|
| 162 |
-
"""Write pre computed caches for mini batches. This can
|
| 163 |
-
make loading a lot more efficient depending on your filesystem.
|
| 164 |
-
|
| 165 |
-
Args:
|
| 166 |
-
cache_folder (Path): folder in which the cached minibatches
|
| 167 |
-
will be stored.
|
| 168 |
-
|
| 169 |
-
Inside cache folder, the structure is the following:
|
| 170 |
-
`epoch_number / update_number.zip`
|
| 171 |
-
And the zip file contains one entry per batch item.
|
| 172 |
-
|
| 173 |
-
It is possible to use the cache with a batch size smaller than
|
| 174 |
-
created with but obviously not larger. Make sure to call the
|
| 175 |
-
`start_epoch(epoch)` method for indicating changes of epochs.
|
| 176 |
-
|
| 177 |
-
See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
|
| 178 |
-
for an example of how to warmup the cache.
|
| 179 |
-
"""
|
| 180 |
-
def __init__(self, cache_folder: Path):
|
| 181 |
-
self.cache_folder = cache_folder
|
| 182 |
-
self._current_epoch: tp.Optional[int] = None
|
| 183 |
-
self._current_index = 0
|
| 184 |
-
|
| 185 |
-
def start_epoch(self, epoch: int):
|
| 186 |
-
"""Call at the beginning of each epoch.
|
| 187 |
-
"""
|
| 188 |
-
self._current_epoch = epoch
|
| 189 |
-
self._current_index = 0
|
| 190 |
-
self._zip_path.parent.mkdir(exist_ok=True, parents=True)
|
| 191 |
-
|
| 192 |
-
@staticmethod
|
| 193 |
-
def _get_zip_path(cache_folder: Path, epoch: int, index: int):
|
| 194 |
-
return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
|
| 195 |
-
|
| 196 |
-
@property
|
| 197 |
-
def _zip_path(self):
|
| 198 |
-
assert self._current_epoch is not None
|
| 199 |
-
return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
|
| 200 |
-
|
| 201 |
-
def save(self, *content):
|
| 202 |
-
"""Save one mini batch. This function is distributed-aware
|
| 203 |
-
and will automatically merge all the items from the different
|
| 204 |
-
workers.
|
| 205 |
-
"""
|
| 206 |
-
all_contents = []
|
| 207 |
-
for rank in range(flashy.distrib.world_size()):
|
| 208 |
-
their_content = flashy.distrib.broadcast_object(content, src=rank)
|
| 209 |
-
all_contents.append(their_content)
|
| 210 |
-
|
| 211 |
-
if flashy.distrib.is_rank_zero():
|
| 212 |
-
idx = 0
|
| 213 |
-
with flashy.utils.write_and_rename(self._zip_path) as tmp:
|
| 214 |
-
with zipfile.ZipFile(tmp, 'w') as zf:
|
| 215 |
-
for content in all_contents:
|
| 216 |
-
for vals in zip(*content):
|
| 217 |
-
with zf.open(f'{idx}', 'w') as f: # type: ignore
|
| 218 |
-
torch.save(vals, f)
|
| 219 |
-
idx += 1
|
| 220 |
-
flashy.distrib.barrier()
|
| 221 |
-
self._current_index += 1
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
class CachedBatchLoader:
|
| 225 |
-
"""Loader for cached mini-batches dumped with `CachedBatchWriter`.
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
cache_folder (Path): folder in which the cached minibatches are stored.
|
| 229 |
-
batch_size (int): batch size (per GPU) expected.
|
| 230 |
-
num_workers (int): number of workers to use for loading.
|
| 231 |
-
min_length (int): minimum expected length for each epoch. If some
|
| 232 |
-
mini-batches are missing, and error is raised.
|
| 233 |
-
|
| 234 |
-
This is iterable just like a regular DataLoader.
|
| 235 |
-
"""
|
| 236 |
-
|
| 237 |
-
def __init__(self, cache_folder: Path, batch_size: int,
|
| 238 |
-
num_workers: int = 10, min_length: int = 1):
|
| 239 |
-
self.cache_folder = cache_folder
|
| 240 |
-
self.batch_size = batch_size
|
| 241 |
-
self.num_workers = num_workers
|
| 242 |
-
self.min_length = min_length
|
| 243 |
-
self._current_epoch: tp.Optional[int] = None
|
| 244 |
-
self.sampler = None # for compatibility with the regular DataLoader
|
| 245 |
-
|
| 246 |
-
def __len__(self):
|
| 247 |
-
path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
|
| 248 |
-
return len([p for p in path.iterdir() if p.suffix == ".zip"])
|
| 249 |
-
|
| 250 |
-
def start_epoch(self, epoch: int):
|
| 251 |
-
"""Call at the beginning of each epoch.
|
| 252 |
-
"""
|
| 253 |
-
self._current_epoch = epoch
|
| 254 |
-
|
| 255 |
-
def _zip_path(self, index: int):
|
| 256 |
-
assert self._current_epoch is not None
|
| 257 |
-
return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
|
| 258 |
-
|
| 259 |
-
def _load_one(self, index: int):
|
| 260 |
-
zip_path = self._zip_path(index)
|
| 261 |
-
if not zip_path.exists():
|
| 262 |
-
if index < self.min_length:
|
| 263 |
-
raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
|
| 264 |
-
|
| 265 |
-
return None
|
| 266 |
-
mode = "rb" if sys.version_info >= (3, 9) else "r"
|
| 267 |
-
try:
|
| 268 |
-
with zipfile.ZipFile(zip_path, 'r') as zf:
|
| 269 |
-
rank = flashy.distrib.rank()
|
| 270 |
-
world_size = flashy.distrib.world_size()
|
| 271 |
-
root = zipfile.Path(zf)
|
| 272 |
-
items = list(root.iterdir())
|
| 273 |
-
total_batch_size = self.batch_size * world_size
|
| 274 |
-
if len(items) < total_batch_size:
|
| 275 |
-
raise RuntimeError(
|
| 276 |
-
f"The cache can handle a max batch size of {len(items)}, "
|
| 277 |
-
f"but {total_batch_size} is needed.")
|
| 278 |
-
start = rank * self.batch_size
|
| 279 |
-
items = items[start: start + self.batch_size]
|
| 280 |
-
assert len(items) == self.batch_size
|
| 281 |
-
entries = []
|
| 282 |
-
entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore
|
| 283 |
-
transposed = zip(*entries)
|
| 284 |
-
out = []
|
| 285 |
-
for part in transposed:
|
| 286 |
-
assert len(part) > 0
|
| 287 |
-
if isinstance(part[0], torch.Tensor):
|
| 288 |
-
out.append(torch.stack(part))
|
| 289 |
-
else:
|
| 290 |
-
assert isinstance(part, torch.Tensor)
|
| 291 |
-
out.append(part)
|
| 292 |
-
return out
|
| 293 |
-
except Exception:
|
| 294 |
-
logger.error("Error when reading zip path %s", zip_path)
|
| 295 |
-
raise
|
| 296 |
-
|
| 297 |
-
def __iter__(self):
|
| 298 |
-
"""This will yields tuples, exactly as provided to the
|
| 299 |
-
`CachedBatchWriter.save` method.
|
| 300 |
-
"""
|
| 301 |
-
pool = ThreadPoolExecutor(self.num_workers)
|
| 302 |
-
next_index = 0
|
| 303 |
-
queue = deque()
|
| 304 |
-
|
| 305 |
-
def _get_next():
|
| 306 |
-
nonlocal next_index
|
| 307 |
-
r = queue.popleft().result()
|
| 308 |
-
if r is None:
|
| 309 |
-
return None
|
| 310 |
-
else:
|
| 311 |
-
queue.append(pool.submit(self._load_one, next_index))
|
| 312 |
-
next_index += 1
|
| 313 |
-
return r
|
| 314 |
-
|
| 315 |
-
with pool:
|
| 316 |
-
# fill the buffer of fetching jobs.
|
| 317 |
-
for _ in range(2 * self.num_workers):
|
| 318 |
-
queue.append(pool.submit(self._load_one, next_index))
|
| 319 |
-
next_index += 1
|
| 320 |
-
while True:
|
| 321 |
-
batch = _get_next()
|
| 322 |
-
if batch is None:
|
| 323 |
-
return
|
| 324 |
-
yield batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/cluster.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Utility functions for SLURM configuration and cluster settings.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from enum import Enum
|
| 12 |
-
import os
|
| 13 |
-
import socket
|
| 14 |
-
import typing as tp
|
| 15 |
-
|
| 16 |
-
import omegaconf
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class ClusterType(Enum):
|
| 20 |
-
AWS = "aws"
|
| 21 |
-
FAIR = "fair"
|
| 22 |
-
RSC = "rsc"
|
| 23 |
-
LOCAL_DARWIN = "darwin"
|
| 24 |
-
DEFAULT = "default" # used for any other cluster.
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _guess_cluster_type() -> ClusterType:
|
| 28 |
-
uname = os.uname()
|
| 29 |
-
fqdn = socket.getfqdn()
|
| 30 |
-
if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
|
| 31 |
-
return ClusterType.AWS
|
| 32 |
-
|
| 33 |
-
if fqdn.endswith(".fair"):
|
| 34 |
-
return ClusterType.FAIR
|
| 35 |
-
|
| 36 |
-
if fqdn.endswith(".facebook.com"):
|
| 37 |
-
return ClusterType.RSC
|
| 38 |
-
|
| 39 |
-
if uname.sysname == "Darwin":
|
| 40 |
-
return ClusterType.LOCAL_DARWIN
|
| 41 |
-
|
| 42 |
-
return ClusterType.DEFAULT
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def get_cluster_type(
|
| 46 |
-
cluster_type: tp.Optional[ClusterType] = None,
|
| 47 |
-
) -> tp.Optional[ClusterType]:
|
| 48 |
-
if cluster_type is None:
|
| 49 |
-
return _guess_cluster_type()
|
| 50 |
-
|
| 51 |
-
return cluster_type
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def get_slurm_parameters(
|
| 55 |
-
cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
|
| 56 |
-
) -> omegaconf.DictConfig:
|
| 57 |
-
"""Update SLURM parameters in configuration based on cluster type.
|
| 58 |
-
If the cluster type is not specify, it infers it automatically.
|
| 59 |
-
"""
|
| 60 |
-
from ..environment import AudioCraftEnvironment
|
| 61 |
-
cluster_type = get_cluster_type(cluster_type)
|
| 62 |
-
# apply cluster-specific adjustments
|
| 63 |
-
if cluster_type == ClusterType.AWS:
|
| 64 |
-
cfg["mem_per_gpu"] = None
|
| 65 |
-
cfg["constraint"] = None
|
| 66 |
-
cfg["setup"] = []
|
| 67 |
-
elif cluster_type == ClusterType.RSC:
|
| 68 |
-
cfg["mem_per_gpu"] = None
|
| 69 |
-
cfg["setup"] = []
|
| 70 |
-
cfg["constraint"] = None
|
| 71 |
-
cfg["partition"] = "learn"
|
| 72 |
-
slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
|
| 73 |
-
if slurm_exclude is not None:
|
| 74 |
-
cfg["exclude"] = slurm_exclude
|
| 75 |
-
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/export.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Utility to export a training checkpoint to a lightweight release checkpoint.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import typing as tp
|
| 13 |
-
|
| 14 |
-
from omegaconf import OmegaConf
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
from audiocraft import __version__
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
|
| 21 |
-
"""Export only the best state from the given EnCodec checkpoint. This
|
| 22 |
-
should be used if you trained your own EnCodec model.
|
| 23 |
-
"""
|
| 24 |
-
pkg = torch.load(checkpoint_path, 'cpu')
|
| 25 |
-
new_pkg = {
|
| 26 |
-
'best_state': pkg['best_state']['model'],
|
| 27 |
-
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
|
| 28 |
-
'version': __version__,
|
| 29 |
-
'exported': True,
|
| 30 |
-
}
|
| 31 |
-
Path(out_file).parent.mkdir(exist_ok=True, parents=True)
|
| 32 |
-
torch.save(new_pkg, out_file)
|
| 33 |
-
return out_file
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
|
| 37 |
-
"""Export a compression model (potentially EnCodec) from a pretrained model.
|
| 38 |
-
This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
|
| 39 |
-
Do not include the //pretrained/ prefix. For instance if you trained a model
|
| 40 |
-
with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
|
| 41 |
-
|
| 42 |
-
In that case, this will not actually include a copy of the model, simply the reference
|
| 43 |
-
to the model used.
|
| 44 |
-
"""
|
| 45 |
-
if Path(pretrained_encodec).exists():
|
| 46 |
-
pkg = torch.load(pretrained_encodec)
|
| 47 |
-
assert 'best_state' in pkg
|
| 48 |
-
assert 'xp.cfg' in pkg
|
| 49 |
-
assert 'version' in pkg
|
| 50 |
-
assert 'exported' in pkg
|
| 51 |
-
else:
|
| 52 |
-
pkg = {
|
| 53 |
-
'pretrained': pretrained_encodec,
|
| 54 |
-
'exported': True,
|
| 55 |
-
'version': __version__,
|
| 56 |
-
}
|
| 57 |
-
Path(out_file).parent.mkdir(exist_ok=True, parents=True)
|
| 58 |
-
torch.save(pkg, out_file)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
|
| 62 |
-
"""Export only the best state from the given MusicGen or AudioGen checkpoint.
|
| 63 |
-
"""
|
| 64 |
-
pkg = torch.load(checkpoint_path, 'cpu')
|
| 65 |
-
if pkg['fsdp_best_state']:
|
| 66 |
-
best_state = pkg['fsdp_best_state']['model']
|
| 67 |
-
else:
|
| 68 |
-
assert pkg['best_state']
|
| 69 |
-
best_state = pkg['best_state']['model']
|
| 70 |
-
new_pkg = {
|
| 71 |
-
'best_state': best_state,
|
| 72 |
-
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
|
| 73 |
-
'version': __version__,
|
| 74 |
-
'exported': True,
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
Path(out_file).parent.mkdir(exist_ok=True, parents=True)
|
| 78 |
-
torch.save(new_pkg, out_file)
|
| 79 |
-
return out_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/export_legacy.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""
|
| 8 |
-
Utility to export a training checkpoint to a lightweight release checkpoint.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import typing as tp
|
| 13 |
-
|
| 14 |
-
from omegaconf import OmegaConf, DictConfig
|
| 15 |
-
import torch
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _clean_lm_cfg(cfg: DictConfig):
|
| 19 |
-
OmegaConf.set_struct(cfg, False)
|
| 20 |
-
# This used to be set automatically in the LM solver, need a more robust solution
|
| 21 |
-
# for the future.
|
| 22 |
-
cfg['transformer_lm']['card'] = 2048
|
| 23 |
-
cfg['transformer_lm']['n_q'] = 4
|
| 24 |
-
# Experimental params no longer supported.
|
| 25 |
-
bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
|
| 26 |
-
'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
|
| 27 |
-
for name in bad_params:
|
| 28 |
-
del cfg['transformer_lm'][name]
|
| 29 |
-
OmegaConf.set_struct(cfg, True)
|
| 30 |
-
return cfg
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
|
| 34 |
-
sig = Path(checkpoint_path).parent.name
|
| 35 |
-
assert len(sig) == 8, "Not a valid Dora signature"
|
| 36 |
-
pkg = torch.load(checkpoint_path, 'cpu')
|
| 37 |
-
new_pkg = {
|
| 38 |
-
'best_state': pkg['ema']['state']['model'],
|
| 39 |
-
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
|
| 40 |
-
}
|
| 41 |
-
out_file = Path(out_folder) / f'{sig}.th'
|
| 42 |
-
torch.save(new_pkg, out_file)
|
| 43 |
-
return out_file
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
|
| 47 |
-
sig = Path(checkpoint_path).parent.name
|
| 48 |
-
assert len(sig) == 8, "Not a valid Dora signature"
|
| 49 |
-
pkg = torch.load(checkpoint_path, 'cpu')
|
| 50 |
-
new_pkg = {
|
| 51 |
-
'best_state': pkg['fsdp_best_state']['model'],
|
| 52 |
-
'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
|
| 53 |
-
}
|
| 54 |
-
out_file = Path(out_folder) / f'{sig}.th'
|
| 55 |
-
torch.save(new_pkg, out_file)
|
| 56 |
-
return out_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/extend.py
DELETED
|
@@ -1,440 +0,0 @@
|
|
| 1 |
-
from tabnanny import verbose
|
| 2 |
-
import torch
|
| 3 |
-
import math
|
| 4 |
-
from audiocraft.models import MusicGen
|
| 5 |
-
import numpy as np
|
| 6 |
-
from PIL import Image, ImageDraw, ImageFont, ImageColor
|
| 7 |
-
import string
|
| 8 |
-
import tempfile
|
| 9 |
-
import os
|
| 10 |
-
import textwrap
|
| 11 |
-
import requests
|
| 12 |
-
from io import BytesIO
|
| 13 |
-
from huggingface_hub import hf_hub_download
|
| 14 |
-
import librosa
|
| 15 |
-
import gradio as gr
|
| 16 |
-
import re
|
| 17 |
-
from tqdm import tqdm
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
INTERRUPTING = False
|
| 21 |
-
|
| 22 |
-
def separate_audio_segments(audio, segment_duration=30, overlap=1):
|
| 23 |
-
sr, audio_data = audio[0], audio[1]
|
| 24 |
-
|
| 25 |
-
segment_samples = sr * segment_duration
|
| 26 |
-
total_samples = max(min((len(audio_data) // segment_samples), 25), 0)
|
| 27 |
-
overlap_samples = sr * overlap
|
| 28 |
-
|
| 29 |
-
segments = []
|
| 30 |
-
start_sample = 0
|
| 31 |
-
# handle the case where the audio is shorter than the segment duration
|
| 32 |
-
if total_samples == 0:
|
| 33 |
-
total_samples = 1
|
| 34 |
-
segment_samples = len(audio_data)
|
| 35 |
-
overlap_samples = 0
|
| 36 |
-
while total_samples >= segment_samples:
|
| 37 |
-
# Collect the segment
|
| 38 |
-
# the end sample is the start sample plus the segment samples,
|
| 39 |
-
# the start sample, after 0, is minus the overlap samples to account for the overlap
|
| 40 |
-
end_sample = start_sample + segment_samples
|
| 41 |
-
segment = audio_data[start_sample:end_sample]
|
| 42 |
-
segments.append((sr, segment))
|
| 43 |
-
|
| 44 |
-
start_sample += segment_samples - overlap_samples
|
| 45 |
-
total_samples -= segment_samples
|
| 46 |
-
|
| 47 |
-
# Collect the final segment
|
| 48 |
-
if total_samples > 0:
|
| 49 |
-
segment = audio_data[-segment_samples:]
|
| 50 |
-
segments.append((sr, segment))
|
| 51 |
-
print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
|
| 52 |
-
return segments
|
| 53 |
-
|
| 54 |
-
def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, excerpt_duration:float=3.5, progress= gr.Progress(track_tqdm=True)):
|
| 55 |
-
# generate audio segments
|
| 56 |
-
melody_segments = separate_audio_segments(melody, segment_duration, 0)
|
| 57 |
-
|
| 58 |
-
# Create lists to store the melody tensors for each segment
|
| 59 |
-
melodys = []
|
| 60 |
-
output_segments = []
|
| 61 |
-
last_chunk = []
|
| 62 |
-
text += ", seed=" + str(seed)
|
| 63 |
-
prompt_segment = None
|
| 64 |
-
# prevent hacking
|
| 65 |
-
duration = min(duration, 720)
|
| 66 |
-
overlap = min(overlap, 15)
|
| 67 |
-
|
| 68 |
-
# Calculate the total number of segments
|
| 69 |
-
total_segments = max(math.ceil(duration / segment_duration),1)
|
| 70 |
-
#calculate duration loss from segment overlap
|
| 71 |
-
duration_loss = max(total_segments - 1,0) * math.ceil(overlap / 2)
|
| 72 |
-
#calc excess duration
|
| 73 |
-
excess_duration = segment_duration - (total_segments * segment_duration - duration)
|
| 74 |
-
print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
|
| 75 |
-
duration += duration_loss
|
| 76 |
-
pbar = tqdm(total=total_segments*2, desc="Generating segments", leave=False)
|
| 77 |
-
while excess_duration + duration_loss > segment_duration:
|
| 78 |
-
total_segments += 1
|
| 79 |
-
#calculate duration loss from segment overlap
|
| 80 |
-
duration_loss += math.ceil(overlap / 2)
|
| 81 |
-
#calc excess duration
|
| 82 |
-
excess_duration = segment_duration - (total_segments * segment_duration - duration)
|
| 83 |
-
print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
|
| 84 |
-
if excess_duration + duration_loss > segment_duration:
|
| 85 |
-
duration += duration_loss
|
| 86 |
-
duration_loss = 0
|
| 87 |
-
pbar.update(1)
|
| 88 |
-
total_segments = min(total_segments, (720 // segment_duration))
|
| 89 |
-
|
| 90 |
-
# If melody_segments is shorter than total_segments, repeat the segments until the total_segments is reached
|
| 91 |
-
if len(melody_segments) < total_segments:
|
| 92 |
-
#fix melody_segments
|
| 93 |
-
for i in range(total_segments - len(melody_segments)):
|
| 94 |
-
segment = melody_segments[i]
|
| 95 |
-
melody_segments.append(segment)
|
| 96 |
-
pbar.update(1)
|
| 97 |
-
print(f"melody_segments: {len(melody_segments)} fixed")
|
| 98 |
-
|
| 99 |
-
# Iterate over the segments to create list of Melody tensors
|
| 100 |
-
for segment_idx in range(total_segments):
|
| 101 |
-
if INTERRUPTING:
|
| 102 |
-
return [], duration
|
| 103 |
-
print(f"segment {segment_idx + 1} of {total_segments} \r")
|
| 104 |
-
|
| 105 |
-
if harmony_only:
|
| 106 |
-
# REMOVE PERCUSION FROM MELODY
|
| 107 |
-
# Apply HPSS using librosa
|
| 108 |
-
verse_harmonic, verse_percussive = librosa.effects.hpss(melody_segments[segment_idx][1])
|
| 109 |
-
# Convert the separated components back to torch.Tensor
|
| 110 |
-
#harmonic_tensor = torch.from_numpy(verse_harmonic)
|
| 111 |
-
#percussive_tensor = torch.from_numpy(verse_percussive)
|
| 112 |
-
sr, verse = melody_segments[segment_idx][0], torch.from_numpy(verse_harmonic).to(MODEL.device).float().t().unsqueeze(0)
|
| 113 |
-
else:
|
| 114 |
-
sr, verse = melody_segments[segment_idx][0], torch.from_numpy(melody_segments[segment_idx][1]).to(MODEL.device).float().t().unsqueeze(0)
|
| 115 |
-
|
| 116 |
-
print(f"shape:{verse.shape} dim:{verse.dim()}")
|
| 117 |
-
#if verse is 2D, add 3rd dimension
|
| 118 |
-
if verse.dim() == 2:
|
| 119 |
-
verse = verse[None]
|
| 120 |
-
verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
|
| 121 |
-
|
| 122 |
-
# Reduce the length of verse to sr * excerpt_duration
|
| 123 |
-
if ("style" in MODEL.name):
|
| 124 |
-
verse = verse[:, :, :int(sr * excerpt_duration)]
|
| 125 |
-
|
| 126 |
-
# Append the segment to the melodys list
|
| 127 |
-
melodys.append(verse)
|
| 128 |
-
pbar.update(1)
|
| 129 |
-
pbar.close()
|
| 130 |
-
torch.manual_seed(seed)
|
| 131 |
-
|
| 132 |
-
# If user selects a prompt segment, generate a new prompt segment to use on all segments
|
| 133 |
-
#default to the first segment for prompt conditioning
|
| 134 |
-
prompt_verse = melodys[0]
|
| 135 |
-
if prompt_index > 0:
|
| 136 |
-
# Get a prompt segment from the selected verse, normally the first verse
|
| 137 |
-
prompt_verse = melodys[prompt_index if prompt_index <= (total_segments - 1) else (total_segments -1)]
|
| 138 |
-
|
| 139 |
-
# set the prompt segment MODEL generation params
|
| 140 |
-
MODEL.set_generation_params(
|
| 141 |
-
use_sampling=True,
|
| 142 |
-
top_k=MODEL.generation_params["top_k"],
|
| 143 |
-
top_p=MODEL.generation_params["top_p"],
|
| 144 |
-
temperature=MODEL.generation_params["temp"],
|
| 145 |
-
cfg_coef=MODEL.generation_params["cfg_coef"],
|
| 146 |
-
cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
|
| 147 |
-
duration=segment_duration,
|
| 148 |
-
two_step_cfg=False,
|
| 149 |
-
rep_penalty=0.5,
|
| 150 |
-
)
|
| 151 |
-
if ("style" in MODEL.name):
|
| 152 |
-
MODEL.set_style_conditioner_params(
|
| 153 |
-
eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
|
| 154 |
-
excerpt_length=excerpt_duration, # the length in seconds that is taken by the model in the provided excerpt, can be between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
# Generate a new prompt segment. This will be applied to all segments for consistency
|
| 158 |
-
print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
|
| 159 |
-
prompt_segment = MODEL.generate_with_all(
|
| 160 |
-
descriptions=[text],
|
| 161 |
-
melody_wavs=prompt_verse,
|
| 162 |
-
sample_rate=sr,
|
| 163 |
-
progress=False,
|
| 164 |
-
prompt=None,
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
for idx, verse in tqdm(enumerate(melodys), total=len(melodys), desc="Generating melody segments"):
|
| 168 |
-
if INTERRUPTING:
|
| 169 |
-
return output_segments, duration
|
| 170 |
-
|
| 171 |
-
print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss}')
|
| 172 |
-
# Compensate for the length of final segment
|
| 173 |
-
if ((idx + 1) == len(melodys)) or (duration < segment_duration):
|
| 174 |
-
mod_duration = max(min(duration, segment_duration),1)
|
| 175 |
-
print(f'Modify verse length, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss} to mod duration: {mod_duration}')
|
| 176 |
-
MODEL.set_generation_params(
|
| 177 |
-
use_sampling=True,
|
| 178 |
-
top_k=MODEL.generation_params["top_k"],
|
| 179 |
-
top_p=MODEL.generation_params["top_p"],
|
| 180 |
-
temperature=MODEL.generation_params["temp"],
|
| 181 |
-
cfg_coef=MODEL.generation_params["cfg_coef"],
|
| 182 |
-
cfg_coef_beta=MODEL.generation_params["cfg_coef_beta"],
|
| 183 |
-
duration=mod_duration,
|
| 184 |
-
two_step_cfg=False,
|
| 185 |
-
rep_penalty=0.5,
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
if ("style" in MODEL.name):
|
| 189 |
-
MODEL.set_style_conditioner_params(
|
| 190 |
-
eval_q=MODEL.lm.condition_provider.conditioners.self_wav.eval_q, # integer between 1 and 6
|
| 191 |
-
excerpt_length=min(excerpt_duration, mod_duration), # the length in seconds that is taken by the model in the provided excerpt, can be between 1.5 and 4.5 seconds but it has to be shortest to the length of the provided conditioning
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
try:
|
| 195 |
-
# get last chunk
|
| 196 |
-
verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
|
| 197 |
-
prompt_segment = prompt_segment[:, :, -mod_duration*MODEL.sample_rate:]
|
| 198 |
-
except:
|
| 199 |
-
# get first chunk
|
| 200 |
-
verse = verse[:, :, :mod_duration*MODEL.sample_rate]
|
| 201 |
-
prompt_segment = prompt_segment[:, :, :mod_duration*MODEL.sample_rate]
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
print(f"Generating New Melody Segment {idx + 1}: {text}\r")
|
| 205 |
-
output, tokens = MODEL.generate_with_all(
|
| 206 |
-
descriptions=[text],
|
| 207 |
-
melody_wavs=verse,
|
| 208 |
-
sample_rate=sr,
|
| 209 |
-
progress=True,
|
| 210 |
-
prompt=prompt_segment,
|
| 211 |
-
return_tokens = True
|
| 212 |
-
)
|
| 213 |
-
# If user selects a prompt segment, use the prompt segment for all segments
|
| 214 |
-
# Otherwise, use the previous segment as the prompt
|
| 215 |
-
if prompt_index < 0:
|
| 216 |
-
if harmony_only:
|
| 217 |
-
# REMOVE PERCUSION FROM MELODY
|
| 218 |
-
# Apply HPSS using librosa
|
| 219 |
-
verse_harmonic, verse_percussive = librosa.effects.hpss(output.detach().cpu().numpy())
|
| 220 |
-
# Convert the separated components back to torch.Tensor
|
| 221 |
-
#harmonic_tensor = torch.from_numpy(verse_harmonic)
|
| 222 |
-
#percussive_tensor = torch.from_numpy(verse_percussive)
|
| 223 |
-
verse = torch.from_numpy(verse_harmonic).to(MODEL.device).float()
|
| 224 |
-
# if verse is 2D, add extra dimension
|
| 225 |
-
if verse.dim() == 2:
|
| 226 |
-
verse = verse[None]
|
| 227 |
-
output = verse
|
| 228 |
-
prompt_segment = output
|
| 229 |
-
|
| 230 |
-
# Append the generated output to the list of segments
|
| 231 |
-
#output_segments.append(output[:, :segment_duration])
|
| 232 |
-
output_segments.append(output)
|
| 233 |
-
print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
|
| 234 |
-
#track duration
|
| 235 |
-
if duration > segment_duration:
|
| 236 |
-
duration -= segment_duration
|
| 237 |
-
return output_segments, excess_duration
|
| 238 |
-
|
| 239 |
-
def save_image(image):
|
| 240 |
-
"""
|
| 241 |
-
Saves a PIL image to a temporary file and returns the file path.
|
| 242 |
-
|
| 243 |
-
Parameters:
|
| 244 |
-
- image: PIL.Image
|
| 245 |
-
The PIL image object to be saved.
|
| 246 |
-
|
| 247 |
-
Returns:
|
| 248 |
-
- str or None: The file path where the image was saved,
|
| 249 |
-
or None if there was an error saving the image.
|
| 250 |
-
|
| 251 |
-
"""
|
| 252 |
-
temp_dir = tempfile.gettempdir()
|
| 253 |
-
temp_file = tempfile.NamedTemporaryFile(suffix=".png", dir=temp_dir, delete=False)
|
| 254 |
-
temp_file.close()
|
| 255 |
-
file_path = temp_file.name
|
| 256 |
-
|
| 257 |
-
try:
|
| 258 |
-
image.save(file_path)
|
| 259 |
-
|
| 260 |
-
except Exception as e:
|
| 261 |
-
print("Unable to save image:", str(e))
|
| 262 |
-
return None
|
| 263 |
-
finally:
|
| 264 |
-
return file_path
|
| 265 |
-
|
| 266 |
-
def detect_color_format(color):
|
| 267 |
-
"""
|
| 268 |
-
Detects if the color is in RGB, RGBA, or hex format,
|
| 269 |
-
and converts it to an RGBA tuple with integer components.
|
| 270 |
-
|
| 271 |
-
Args:
|
| 272 |
-
color (str or tuple): The color to detect.
|
| 273 |
-
|
| 274 |
-
Returns:
|
| 275 |
-
tuple: The color in RGBA format as a tuple of 4 integers.
|
| 276 |
-
|
| 277 |
-
Raises:
|
| 278 |
-
ValueError: If the input color is not in a recognized format.
|
| 279 |
-
"""
|
| 280 |
-
# Handle color as a tuple of floats or integers
|
| 281 |
-
if isinstance(color, tuple):
|
| 282 |
-
if len(color) == 3 or len(color) == 4:
|
| 283 |
-
# Ensure all components are numbers
|
| 284 |
-
if all(isinstance(c, (int, float)) for c in color):
|
| 285 |
-
r, g, b = color[:3]
|
| 286 |
-
a = color[3] if len(color) == 4 else 255
|
| 287 |
-
return (
|
| 288 |
-
max(0, min(255, int(round(r)))),
|
| 289 |
-
max(0, min(255, int(round(g)))),
|
| 290 |
-
max(0, min(255, int(round(b)))),
|
| 291 |
-
max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
|
| 292 |
-
)
|
| 293 |
-
else:
|
| 294 |
-
raise ValueError(f"Invalid color tuple length: {len(color)}")
|
| 295 |
-
# Handle hex color codes
|
| 296 |
-
if isinstance(color, str):
|
| 297 |
-
color = color.strip()
|
| 298 |
-
# Try to use PIL's ImageColor
|
| 299 |
-
try:
|
| 300 |
-
rgba = ImageColor.getcolor(color, "RGBA")
|
| 301 |
-
return rgba
|
| 302 |
-
except ValueError:
|
| 303 |
-
pass
|
| 304 |
-
# Handle 'rgba(r, g, b, a)' string format
|
| 305 |
-
rgba_match = re.match(r'rgba\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
|
| 306 |
-
if rgba_match:
|
| 307 |
-
r, g, b, a = map(float, rgba_match.groups())
|
| 308 |
-
return (
|
| 309 |
-
max(0, min(255, int(round(r)))),
|
| 310 |
-
max(0, min(255, int(round(g)))),
|
| 311 |
-
max(0, min(255, int(round(b)))),
|
| 312 |
-
max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
|
| 313 |
-
)
|
| 314 |
-
# Handle 'rgb(r, g, b)' string format
|
| 315 |
-
rgb_match = re.match(r'rgb\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
|
| 316 |
-
if rgb_match:
|
| 317 |
-
r, g, b = map(float, rgb_match.groups())
|
| 318 |
-
return (
|
| 319 |
-
max(0, min(255, int(round(r)))),
|
| 320 |
-
max(0, min(255, int(round(g)))),
|
| 321 |
-
max(0, min(255, int(round(b)))),
|
| 322 |
-
255,
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
# If none of the above conversions work, raise an error
|
| 326 |
-
raise ValueError(f"Invalid color format: {color}")
|
| 327 |
-
|
| 328 |
-
def hex_to_rgba(hex_color):
|
| 329 |
-
try:
|
| 330 |
-
if hex_color.startswith("#"):
|
| 331 |
-
clean_hex = hex_color.replace('#','')
|
| 332 |
-
# Use a generator expression to convert pairs of hexadecimal digits to integers and create a tuple
|
| 333 |
-
rgba = tuple(int(clean_hex[i:i+2], 16) for i in range(0, len(clean_hex),2))
|
| 334 |
-
else:
|
| 335 |
-
rgba = tuple(map(int,detect_color_format(hex_color)))
|
| 336 |
-
except ValueError:
|
| 337 |
-
# If the hex color is invalid, default to yellow
|
| 338 |
-
rgba = (255,255,0,255)
|
| 339 |
-
return rgba
|
| 340 |
-
|
| 341 |
-
def load_font(font_name, font_size=16):
|
| 342 |
-
"""
|
| 343 |
-
Load a font using the provided font name and font size.
|
| 344 |
-
|
| 345 |
-
Parameters:
|
| 346 |
-
font_name (str): The name of the font to load. Can be a font name recognized by the system, a URL to download the font file,
|
| 347 |
-
a local file path, or a Hugging Face model hub identifier.
|
| 348 |
-
font_size (int, optional): The size of the font. Default is 16.
|
| 349 |
-
|
| 350 |
-
Returns:
|
| 351 |
-
ImageFont.FreeTypeFont: The loaded font object.
|
| 352 |
-
|
| 353 |
-
Notes:
|
| 354 |
-
This function attempts to load the font using various methods until a suitable font is found. If the provided font_name
|
| 355 |
-
cannot be loaded, it falls back to a default font.
|
| 356 |
-
|
| 357 |
-
The font_name can be one of the following:
|
| 358 |
-
- A font name recognized by the system, which can be loaded using ImageFont.truetype.
|
| 359 |
-
- A URL pointing to the font file, which is downloaded using requests and then loaded using ImageFont.truetype.
|
| 360 |
-
- A local file path to the font file, which is loaded using ImageFont.truetype.
|
| 361 |
-
- A Hugging Face model hub identifier, which downloads the font file from the Hugging Face model hub using hf_hub_download
|
| 362 |
-
and then loads it using ImageFont.truetype.
|
| 363 |
-
|
| 364 |
-
Example:
|
| 365 |
-
font = load_font("Arial.ttf", font_size=20)
|
| 366 |
-
"""
|
| 367 |
-
font = None
|
| 368 |
-
if not "http" in font_name:
|
| 369 |
-
try:
|
| 370 |
-
font = ImageFont.truetype(font_name, font_size)
|
| 371 |
-
except (FileNotFoundError, OSError):
|
| 372 |
-
print("Font not found. Using Hugging Face download..\n")
|
| 373 |
-
|
| 374 |
-
if font is None:
|
| 375 |
-
try:
|
| 376 |
-
font_path = ImageFont.truetype(hf_hub_download(repo_id=os.environ.get('SPACE_ID', ''), filename="assets/" + font_name, repo_type="space"), encoding="UTF-8")
|
| 377 |
-
font = ImageFont.truetype(font_path, font_size)
|
| 378 |
-
except (FileNotFoundError, OSError):
|
| 379 |
-
print("Font not found. Trying to download from local assets folder...\n")
|
| 380 |
-
if font is None:
|
| 381 |
-
try:
|
| 382 |
-
font = ImageFont.truetype("assets/" + font_name, font_size)
|
| 383 |
-
except (FileNotFoundError, OSError):
|
| 384 |
-
print("Font not found. Trying to download from URL...\n")
|
| 385 |
-
|
| 386 |
-
if font is None:
|
| 387 |
-
try:
|
| 388 |
-
req = requests.get(font_name)
|
| 389 |
-
font = ImageFont.truetype(BytesIO(req.content), font_size)
|
| 390 |
-
except (FileNotFoundError, OSError):
|
| 391 |
-
print(f"Font not found: {font_name} Using default font\n")
|
| 392 |
-
|
| 393 |
-
if font:
|
| 394 |
-
print(f"Font loaded {font.getname()}")
|
| 395 |
-
else:
|
| 396 |
-
font = ImageFont.load_default()
|
| 397 |
-
return font
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
def add_settings_to_image(title: str = "title", description: str = "", width: int = 768, height: int = 512, background_path: str = "", font: str = "arial.ttf", font_color: str = "#ffffff", font_size: int = 28, progress=gr.Progress(track_tqdm=True)):
|
| 401 |
-
# Create a new RGBA image with the specified dimensions
|
| 402 |
-
image = Image.new("RGBA", (width, height), (255, 255, 255, 0))
|
| 403 |
-
# If a background image is specified, open it and paste it onto the image
|
| 404 |
-
if background_path == "":
|
| 405 |
-
background = Image.new("RGBA", (width, height), (255, 255, 255, 255))
|
| 406 |
-
else:
|
| 407 |
-
background = Image.open(background_path).convert("RGBA")
|
| 408 |
-
|
| 409 |
-
#Convert font color to RGBA tuple
|
| 410 |
-
font_color = hex_to_rgba(font_color)
|
| 411 |
-
print(f"Font Color: {font_color}\n")
|
| 412 |
-
|
| 413 |
-
# Calculate the center coordinates for placing the text
|
| 414 |
-
text_x = width // 2
|
| 415 |
-
text_y = height // 2
|
| 416 |
-
# Draw the title text at the center top
|
| 417 |
-
title_font = load_font(font, font_size) # Replace with your desired font and size
|
| 418 |
-
|
| 419 |
-
title_text = '\n'.join(textwrap.wrap(title, width // 12))
|
| 420 |
-
title_x, title_y, title_text_width, title_text_height = title_font.getbbox(title_text)
|
| 421 |
-
title_x = max(text_x - (title_text_width // 2), title_x, 0)
|
| 422 |
-
title_y = text_y - (height // 2) + 10 # 10 pixels padding from the top
|
| 423 |
-
title_draw = ImageDraw.Draw(image)
|
| 424 |
-
title_draw.multiline_text((title_x, title_y), title, fill=font_color, font=title_font, align="center")
|
| 425 |
-
# Draw the description text two lines below the title
|
| 426 |
-
description_font = load_font(font, int(font_size * 2 / 3)) # Replace with your desired font and size
|
| 427 |
-
description_text = '\n'.join(textwrap.wrap(description, width // 12))
|
| 428 |
-
description_x, description_y, description_text_width, description_text_height = description_font.getbbox(description_text)
|
| 429 |
-
description_x = max(text_x - (description_text_width // 2), description_x, 0)
|
| 430 |
-
description_y = title_y + title_text_height + 20 # 20 pixels spacing between title and description
|
| 431 |
-
description_draw = ImageDraw.Draw(image)
|
| 432 |
-
description_draw.multiline_text((description_x, description_y), description_text, fill=font_color, font=description_font, align="center")
|
| 433 |
-
# Calculate the offset to center the image on the background
|
| 434 |
-
bg_w, bg_h = background.size
|
| 435 |
-
offset = ((bg_w - width) // 2, (bg_h - height) // 2)
|
| 436 |
-
# Paste the image onto the background
|
| 437 |
-
background.paste(image, offset, mask=image)
|
| 438 |
-
|
| 439 |
-
# Save the image and return the file path
|
| 440 |
-
return save_image(background)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/notebook.py
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
try:
|
| 8 |
-
import IPython.display as ipd # type: ignore
|
| 9 |
-
except ImportError:
|
| 10 |
-
# Note in a notebook...
|
| 11 |
-
pass
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def display_audio(samples: torch.Tensor, sample_rate: int):
|
| 18 |
-
"""Renders an audio player for the given audio samples.
|
| 19 |
-
|
| 20 |
-
Args:
|
| 21 |
-
samples (torch.Tensor): a Tensor of decoded audio samples
|
| 22 |
-
with shapes [B, C, T] or [C, T]
|
| 23 |
-
sample_rate (int): sample rate audio should be displayed with.
|
| 24 |
-
"""
|
| 25 |
-
assert samples.dim() == 2 or samples.dim() == 3
|
| 26 |
-
|
| 27 |
-
samples = samples.detach().cpu()
|
| 28 |
-
if samples.dim() == 2:
|
| 29 |
-
samples = samples[None, ...]
|
| 30 |
-
|
| 31 |
-
for audio in samples:
|
| 32 |
-
ipd.display(ipd.Audio(audio, rate=sample_rate))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiocraft/utils/utils.py
DELETED
|
@@ -1,328 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from concurrent.futures import ProcessPoolExecutor
|
| 8 |
-
from contextlib import contextmanager
|
| 9 |
-
from functools import wraps, lru_cache
|
| 10 |
-
import hashlib
|
| 11 |
-
import json
|
| 12 |
-
import logging
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
import typing as tp
|
| 15 |
-
|
| 16 |
-
import flashy
|
| 17 |
-
import flashy.distrib
|
| 18 |
-
import omegaconf
|
| 19 |
-
import torch
|
| 20 |
-
from torch.nn.utils.rnn import pad_sequence
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
logger = logging.getLogger(__name__)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def model_hash(model: torch.nn.Module) -> str:
|
| 27 |
-
"""Return a model hash. This should allow us to track regressions in model init
|
| 28 |
-
from the logs of past experiments.
|
| 29 |
-
"""
|
| 30 |
-
hasher = hashlib.sha1()
|
| 31 |
-
for p in model.parameters():
|
| 32 |
-
hasher.update(p.data.cpu().numpy().tobytes())
|
| 33 |
-
return hasher.hexdigest()
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
|
| 39 |
-
"""Convenience function to map an omegaconf configuration to a dictionary.
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
cfg (omegaconf.DictConfig): Original configuration to map to dict.
|
| 43 |
-
Returns:
|
| 44 |
-
dict: Config as dictionary object.
|
| 45 |
-
"""
|
| 46 |
-
dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
|
| 47 |
-
assert isinstance(dct, dict)
|
| 48 |
-
return dct
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
|
| 52 |
-
if max_samples >= len(dataset):
|
| 53 |
-
return dataset
|
| 54 |
-
|
| 55 |
-
generator = torch.Generator().manual_seed(seed)
|
| 56 |
-
perm = torch.randperm(len(dataset), generator=generator)
|
| 57 |
-
return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
|
| 61 |
-
num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
|
| 62 |
-
"""Convenience function to load dataset into a dataloader with optional subset sampling.
|
| 63 |
-
|
| 64 |
-
Args:
|
| 65 |
-
dataset: Dataset to load.
|
| 66 |
-
num_samples (Optional[int]): Number of samples to limit subset size.
|
| 67 |
-
batch_size (int): Batch size.
|
| 68 |
-
num_workers (int): Number of workers for data loading.
|
| 69 |
-
seed (int): Random seed.
|
| 70 |
-
"""
|
| 71 |
-
if num_samples is not None:
|
| 72 |
-
dataset = random_subset(dataset, num_samples, seed)
|
| 73 |
-
|
| 74 |
-
dataloader = flashy.distrib.loader(
|
| 75 |
-
dataset,
|
| 76 |
-
batch_size=batch_size,
|
| 77 |
-
num_workers=num_workers,
|
| 78 |
-
**kwargs
|
| 79 |
-
)
|
| 80 |
-
return dataloader
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def get_dataset_from_loader(dataloader):
|
| 84 |
-
dataset = dataloader.dataset
|
| 85 |
-
if isinstance(dataset, torch.utils.data.Subset):
|
| 86 |
-
return dataset.dataset
|
| 87 |
-
else:
|
| 88 |
-
return dataset
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
| 92 |
-
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
input (torch.Tensor): The input tensor containing probabilities.
|
| 96 |
-
num_samples (int): Number of samples to draw.
|
| 97 |
-
replacement (bool): Whether to draw with replacement or not.
|
| 98 |
-
Keywords args:
|
| 99 |
-
generator (torch.Generator): A pseudorandom number generator for sampling.
|
| 100 |
-
Returns:
|
| 101 |
-
torch.Tensor: Last dimension contains num_samples indices
|
| 102 |
-
sampled from the multinomial probability distribution
|
| 103 |
-
located in the last dimension of tensor input.
|
| 104 |
-
"""
|
| 105 |
-
input_ = input.reshape(-1, input.shape[-1])
|
| 106 |
-
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
| 107 |
-
output = output_.reshape(*list(input.shape[:-1]), -1)
|
| 108 |
-
return output
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
|
| 112 |
-
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
| 113 |
-
|
| 114 |
-
Args:
|
| 115 |
-
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
| 116 |
-
k (int): The k in “top-k”.
|
| 117 |
-
Returns:
|
| 118 |
-
torch.Tensor: Sampled tokens.
|
| 119 |
-
"""
|
| 120 |
-
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
| 121 |
-
min_value_top_k = top_k_value[..., [-1]]
|
| 122 |
-
probs *= (probs >= min_value_top_k).float()
|
| 123 |
-
probs.div_(probs.sum(dim=-1, keepdim=True))
|
| 124 |
-
next_token = multinomial(probs, num_samples=1)
|
| 125 |
-
return next_token
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
| 129 |
-
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
| 133 |
-
p (int): The p in “top-p”.
|
| 134 |
-
Returns:
|
| 135 |
-
torch.Tensor: Sampled tokens.
|
| 136 |
-
"""
|
| 137 |
-
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 138 |
-
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
| 139 |
-
mask = probs_sum - probs_sort > p
|
| 140 |
-
probs_sort *= (~mask).float()
|
| 141 |
-
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
| 142 |
-
next_token = multinomial(probs_sort, num_samples=1)
|
| 143 |
-
next_token = torch.gather(probs_idx, -1, next_token)
|
| 144 |
-
return next_token
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
class DummyPoolExecutor:
|
| 148 |
-
"""Dummy pool executor to use when we actually have only 1 worker.
|
| 149 |
-
(e.g. instead of ProcessPoolExecutor).
|
| 150 |
-
"""
|
| 151 |
-
class DummyResult:
|
| 152 |
-
def __init__(self, func, *args, **kwargs):
|
| 153 |
-
self.func = func
|
| 154 |
-
self.args = args
|
| 155 |
-
self.kwargs = kwargs
|
| 156 |
-
|
| 157 |
-
def result(self):
|
| 158 |
-
return self.func(*self.args, **self.kwargs)
|
| 159 |
-
|
| 160 |
-
def __init__(self, workers, mp_context=None):
|
| 161 |
-
pass
|
| 162 |
-
|
| 163 |
-
def submit(self, func, *args, **kwargs):
|
| 164 |
-
return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
|
| 165 |
-
|
| 166 |
-
def __enter__(self):
|
| 167 |
-
return self
|
| 168 |
-
|
| 169 |
-
def __exit__(self, exc_type, exc_value, exc_tb):
|
| 170 |
-
return
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def get_pool_executor(num_workers: int, mp_context=None):
|
| 174 |
-
return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
|
| 178 |
-
"""Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
|
| 179 |
-
For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
|
| 180 |
-
|
| 181 |
-
Args:
|
| 182 |
-
lengths (torch.Tensor): tensor with lengths
|
| 183 |
-
max_len (int): can set the max length manually. Defaults to None.
|
| 184 |
-
Returns:
|
| 185 |
-
torch.Tensor: mask with 0s where there is pad tokens else 1s
|
| 186 |
-
"""
|
| 187 |
-
assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
|
| 188 |
-
final_length = lengths.max().item() if not max_len else max_len
|
| 189 |
-
final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
|
| 190 |
-
return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
def hash_trick(word: str, vocab_size: int) -> int:
|
| 194 |
-
"""Hash trick to pair each word with an index
|
| 195 |
-
|
| 196 |
-
Args:
|
| 197 |
-
word (str): word we wish to convert to an index
|
| 198 |
-
vocab_size (int): size of the vocabulary
|
| 199 |
-
Returns:
|
| 200 |
-
int: index of the word in the embedding LUT
|
| 201 |
-
"""
|
| 202 |
-
hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
|
| 203 |
-
return hash % vocab_size
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
def with_rank_rng(base_seed: int = 1234):
|
| 207 |
-
"""Decorator for a function so that the function will use a Random Number Generator
|
| 208 |
-
whose state depend on the GPU rank. The original RNG state is restored upon returning.
|
| 209 |
-
|
| 210 |
-
Args:
|
| 211 |
-
base_seed (int): Random seed.
|
| 212 |
-
"""
|
| 213 |
-
def _decorator(fun: tp.Callable):
|
| 214 |
-
@wraps(fun)
|
| 215 |
-
def _decorated(*args, **kwargs):
|
| 216 |
-
state = torch.get_rng_state()
|
| 217 |
-
seed = base_seed ^ flashy.distrib.rank()
|
| 218 |
-
torch.manual_seed(seed)
|
| 219 |
-
logger.debug('Rank dependent seed set to %d', seed)
|
| 220 |
-
try:
|
| 221 |
-
return fun(*args, **kwargs)
|
| 222 |
-
finally:
|
| 223 |
-
torch.set_rng_state(state)
|
| 224 |
-
logger.debug('RNG state restored.')
|
| 225 |
-
return _decorated
|
| 226 |
-
return _decorator
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 230 |
-
"""Get a list of tensors and collate them to a single tensor. according to the following logic:
|
| 231 |
-
- `dim` specifies the time dimension which will be stacked and padded.
|
| 232 |
-
- The output will contain 1 new dimension (dimension index 0) which will be the size of
|
| 233 |
-
of the original list.
|
| 234 |
-
|
| 235 |
-
Args:
|
| 236 |
-
tensors (tp.List[torch.Tensor]): List of tensors to collate.
|
| 237 |
-
dim (int): Dimension which will be stacked and padded.
|
| 238 |
-
Returns:
|
| 239 |
-
tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 240 |
-
torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
|
| 241 |
-
(dimension index 0) which will be the size of the original list.
|
| 242 |
-
torch.Tensor: Tensor containing length of original tensor sizes (without padding).
|
| 243 |
-
"""
|
| 244 |
-
tensors = [x.transpose(0, dim) for x in tensors]
|
| 245 |
-
lens = torch.LongTensor([len(x) for x in tensors])
|
| 246 |
-
padded_tensors = pad_sequence(tensors)
|
| 247 |
-
padded_tensors = padded_tensors.transpose(0, 1)
|
| 248 |
-
padded_tensors = padded_tensors.transpose(1, dim + 1)
|
| 249 |
-
return padded_tensors, lens
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
# TODO: Move to flashy?
|
| 253 |
-
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
|
| 254 |
-
dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
|
| 255 |
-
if isinstance(state, torch.Tensor):
|
| 256 |
-
if dtype is None or not state.is_floating_point():
|
| 257 |
-
dtype = state.dtype
|
| 258 |
-
return state.detach().to(device=device, dtype=dtype, copy=True)
|
| 259 |
-
elif isinstance(state, dict):
|
| 260 |
-
return {k: copy_state(v, device, dtype) for k, v in state.items()}
|
| 261 |
-
elif isinstance(state, list):
|
| 262 |
-
return [copy_state(v, device, dtype) for v in state]
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
# TODO: Move to flashy?
|
| 266 |
-
@contextmanager
|
| 267 |
-
def swap_state(model, state, **kwargs):
|
| 268 |
-
old_state = copy_state(model.state_dict())
|
| 269 |
-
model.load_state_dict(state, **kwargs)
|
| 270 |
-
try:
|
| 271 |
-
yield
|
| 272 |
-
finally:
|
| 273 |
-
model.load_state_dict(old_state)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
@lru_cache(None)
|
| 277 |
-
def warn_once(logger, msg):
|
| 278 |
-
"""Warn about a given message only once."""
|
| 279 |
-
logger.warning(msg)
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def is_jsonable(x: tp.Any):
|
| 283 |
-
"""Check if an object can be serialized into a json:"""
|
| 284 |
-
try:
|
| 285 |
-
json.dumps(x)
|
| 286 |
-
return True
|
| 287 |
-
except (TypeError, OverflowError):
|
| 288 |
-
return False
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
|
| 292 |
-
"""Wrapper around state dict loading of CLAP model
|
| 293 |
-
addressing compatibility issues between CLAP and AudioCraft
|
| 294 |
-
HuggingFace transformer version.
|
| 295 |
-
See: https://github.com/LAION-AI/CLAP/issues/118
|
| 296 |
-
"""
|
| 297 |
-
from clap_module.factory import load_state_dict # type: ignore
|
| 298 |
-
pkg = load_state_dict(path)
|
| 299 |
-
pkg.pop('text_branch.embeddings.position_ids', None)
|
| 300 |
-
clap_model.model.load_state_dict(pkg)
|
| 301 |
-
|
| 302 |
-
def construct_frame_chords(
|
| 303 |
-
min_timestamp: int,
|
| 304 |
-
chord_changes: tp.List[tp.Tuple[float, str]],
|
| 305 |
-
mapping_dict: tp.Dict,
|
| 306 |
-
prev_chord: str,
|
| 307 |
-
frame_rate: float,
|
| 308 |
-
segment_duration: float,
|
| 309 |
-
) -> tp.List[str]:
|
| 310 |
-
""" Translate symbolic chords [(start_time, tuples),...] into a frame-level int sequence"""
|
| 311 |
-
|
| 312 |
-
frames = [
|
| 313 |
-
frame / frame_rate
|
| 314 |
-
for frame in range(
|
| 315 |
-
min_timestamp, int(min_timestamp + segment_duration * frame_rate)
|
| 316 |
-
)
|
| 317 |
-
]
|
| 318 |
-
|
| 319 |
-
frame_chords = []
|
| 320 |
-
current_chord = prev_chord
|
| 321 |
-
|
| 322 |
-
for frame in frames:
|
| 323 |
-
while chord_changes and frame >= chord_changes[0][0]:
|
| 324 |
-
current_chord = chord_changes.pop(0)[1]
|
| 325 |
-
current_chord = 'N' if current_chord in {None, ''} else current_chord
|
| 326 |
-
frame_chords.append(mapping_dict[current_chord])
|
| 327 |
-
|
| 328 |
-
return frame_chords
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/constants.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modules/constants.py
|
| 2 |
+
# constants.py contains all the constants used in the project such as the default LUT example image, prompts, negative prompts, pre-rendered maps, models, LoRA weights, and more.
|
| 3 |
+
# execptions made for some environmental variables
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
IS_SHARED_SPACE = "Agents-MCP-Hackathon/UnlimitedMusicGen" in os.environ.get('SPACE_ID', '')
|
| 12 |
+
|
| 13 |
+
# Load environment variables from .env file
|
| 14 |
+
dotenv_path = Path(__file__).parent.parent / '.env'
|
| 15 |
+
load_dotenv(dotenv_path)
|
| 16 |
+
|
| 17 |
+
# Function to load env vars from .env and create Python variables
|
| 18 |
+
def load_env_vars(env_path):
|
| 19 |
+
try:
|
| 20 |
+
with open(env_path, 'r') as file:
|
| 21 |
+
for line in file:
|
| 22 |
+
# Skip empty lines or comments
|
| 23 |
+
line = line.strip()
|
| 24 |
+
if line and not line.startswith('#'):
|
| 25 |
+
# Split on the first '=' only
|
| 26 |
+
if '=' in line:
|
| 27 |
+
key, value = line.split('=', 1)
|
| 28 |
+
key = key.strip()
|
| 29 |
+
value = value.strip()
|
| 30 |
+
# Dynamically create a Python variable with the key name
|
| 31 |
+
globals()[key] = value
|
| 32 |
+
# Also update os.environ (optional, for consistency)
|
| 33 |
+
os.environ[key] = value
|
| 34 |
+
except FileNotFoundError:
|
| 35 |
+
print(f"Warning: .env file not found at {env_path}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
USE_FLASH_ATTENTION = os.getenv("USE_FLASH_ATTENTION", "0") == "1"
|
| 40 |
+
HF_API_TOKEN = os.getenv("HF_TOKEN")
|
| 41 |
+
if not HF_API_TOKEN:
|
| 42 |
+
raise ValueError("HF_TOKEN is not set. Please check your .env file.")
|
| 43 |
+
|
| 44 |
+
default_lut_example_img = "./LUT/daisy.jpg"
|
| 45 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 46 |
+
TARGET_SIZE = (2688,1536)
|
| 47 |
+
BASE_HEIGHT = 640
|
| 48 |
+
SCALE_FACTOR = (12/5)
|
| 49 |
+
TMPDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
| 50 |
+
os.makedirs(TMPDIR, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Constants for URL shortener
|
| 54 |
+
HF_REPO_ID = "Surn/Storage" # Or your desired repository
|
| 55 |
+
SHORTENER_JSON_FILE = "shortener.json"
|
| 56 |
+
|
| 57 |
+
model_extensions = {".glb", ".gltf", ".obj", ".ply"}
|
| 58 |
+
model_extensions_list = list(model_extensions)
|
| 59 |
+
image_extensions = {".png", ".jpg", ".jpeg", ".webp"}
|
| 60 |
+
image_extensions_list = list(image_extensions)
|
| 61 |
+
music_extensions = {".mp3", ".wav", ".ogg", ".flac"}
|
| 62 |
+
music_extensions_list = list(music_extensions)
|
| 63 |
+
upload_file_types = model_extensions_list + image_extensions_list + music_extensions_list
|