# File size: 72,010 Bytes
# bc8c4af
from typing import Set, Tuple, Optional, List
from enum import Enum
import math
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
class AudioProcessor(nn.Module):
    """Turn raw waveforms into log-mel spectrograms, resampling first when required."""

    def __init__(
        self,
        sample_rate: int = 16000,
        mel_bins: int = 64,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
    ) -> None:
        """
        Build the mel-spectrogram transform.
        Args:
            sample_rate: Sample rate the transform operates at; inputs at another
                rate are resampled to it in `waveform_to_mel`.
            mel_bins: Number of mel filter banks (`n_mels`).
            mel_hop_length: Hop length, in samples, between STFT frames.
            n_fft: FFT size; also used as the window length.
        """
        super().__init__()
        self.sample_rate = sample_rate
        # Magnitude (power=1.0) spectrogram with a Hann window and Slaney mel
        # scale/normalisation, covering 0 Hz up to sample_rate / 2.
        mel_settings = dict(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )
        self.mel_transform = torchaudio.transforms.MelSpectrogram(**mel_settings)

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """
        Return `waveform` converted from `source_rate` to `target_rate`.
        The input tensor itself is returned when both rates are equal; otherwise
        the resampled result is cast back to the input's device and dtype.
        """
        if target_rate == source_rate:
            return waveform
        converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
        return converted.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """
        Convert a waveform into a log-mel spectrogram laid out as
        [batch, channels, time, n_mels].
        Args:
            waveform: Input audio (assumed [batch, channels, samples] - confirm with callers).
            waveform_sample_rate: Sample rate of `waveform` in Hz.
        """
        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
        magnitudes = self.mel_transform(waveform)
        # Clamp before the log so silent bins do not produce -inf.
        log_mel = torch.clamp(magnitudes, min=1e-5).log()
        log_mel = log_mel.to(device=waveform.device, dtype=waveform.dtype)
        # Swap the last two axes so time comes before the mel bins.
        return log_mel.permute(0, 1, 3, 2).contiguous()
class AudioPatchifier(Patchifier):
    """Patchifier for spectrogram/audio latents: one token per latent time frame."""

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Store the timing configuration used to map latent frames to seconds.
        Args:
            patch_size: Patch extent; exposed through `patch_size` as
                `(1, patch_size, patch_size)`.
            sample_rate: Waveform sampling rate in Hz, used to convert frame
                indices into seconds.
            hop_length: Spectrogram hop length in samples, i.e. the spacing
                between consecutive mel frames.
            audio_latent_downsample_factor: Number of mel frames represented
                by one latent frame.
            is_causal: When True, timestamps are shifted back to account for
                the causal receptive field (see `_get_audio_latent_time_in_sec`).
            shift: Integer offset added to the latent indices before they are
                converted to timestamps.
        """
        self._patch_size = (1, patch_size, patch_size)
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch size triple `(1, patch_size, patch_size)` fixed at construction."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """Number of tokens produced for `tgt_shape`: one per latent frame."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Map the latent indices `[start_latent, end_latent)` to times in seconds.
        Args:
            start_latent: First latent index (inclusive).
            end_latent: Last latent index (exclusive).
            dtype: Dtype of the index tensor the timestamps are derived from.
            device: Device for the result; CPU when omitted.
        Returns:
            1-D tensor with one timestamp per latent index.
        """
        target_device = torch.device("cpu") if device is None else device
        latent_index = torch.arange(start_latent, end_latent, dtype=dtype, device=target_device)
        mel_index = latent_index * self.audio_latent_downsample_factor
        if self.is_causal:
            # Move each frame back by (downsample_factor - 1) mel frames and
            # never let the index drop below zero.
            causal_offset = 1
            mel_index = (mel_index + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
        return mel_index * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build the `(B, 1, T, 2)` tensor of `[start, end)` timestamps per latent frame.
        Args:
            batch_size: Batch dimension the timings are expanded to.
            num_steps: Number of latent frames to produce timestamps for.
            device: Device for the result; CPU when omitted.
        """
        target_device = torch.device("cpu") if device is None else device
        first = self.shift
        last = num_steps + self.shift

        def _timings(begin: int, end: int) -> torch.Tensor:
            # (T,) -> (B, T) -> (B, 1, T)
            seconds = self._get_audio_latent_time_in_sec(begin, end, torch.float32, target_device)
            return seconds.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        frame_starts = _timings(first, last)
        # A frame ends where the next one starts, hence the indices shifted by one.
        frame_ends = _timings(first + 1, last + 1)
        return torch.stack([frame_starts, frame_ends], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(B, C, T, F)` latents into `(B, T, C*F)` tokens.
        Timing metadata is not produced here; call `get_patch_grid_bounds` for it.
        Args:
            audio_latents: Latent tensor to patchify.
        Returns:
            Token tensor with channel and frequency axes merged.
        """
        return einops.rearrange(audio_latents, "b c t f -> b t (c f)")

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Rebuild the `(B, C, T, F)` latent tensor from `(B, T, C*F)` tokens.
        Args:
            audio_latents: Token tensor to unpatchify.
            output_shape: Supplies `channels` and `mel_bins` for the split.
        Returns:
            Latents in `(B, C, T, F)` layout.
        """
        return self.unpatchify_audio(
            audio_latents,
            channels=output_shape.channels,
            mel_bins=output_shape.mel_bins,
        )

    def unpatchify_audio(
        self,
        audio_latents: torch.Tensor,
        channels: int,
        mel_bins: int
    ) -> torch.Tensor:
        """Same as `unpatchify`, but takes the channel and mel-bin counts directly."""
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=channels,
            f=mel_bins,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the `[start, end)` time bounds, in seconds, of every token that
        `patchify` emits.
        The result has shape `[batch_size, 1, time_steps, 2]`; the last axis
        holds the start and end timestamp of each latent frame.
        Args:
            output_shape: Audio latent shape providing `batch` and `frames`.
            device: Device for the returned tensor.
        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if isinstance(output_shape, AudioLatentShape):
            return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
        raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
class AttentionType(Enum):
    """Selects which attention implementation `make_attn` should build."""

    # Softmax self-attention (`AttnBlock`).
    VANILLA = "vanilla"
    # Reserved; `make_attn` raises NotImplementedError for it.
    LINEAR = "linear"
    # No attention; `make_attn` returns an identity module.
    NONE = "none"
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over all positions of a 2D feature map, with a residual connection."""

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        """
        Args:
            in_channels: Channel count of the input (and output) feature map.
            norm_type: Normalisation applied before the q/k/v projections.
        """
        super().__init__()
        self.in_channels = in_channels
        self.norm = build_normalization_layer(in_channels, normtype=norm_type)

        def pointwise() -> torch.nn.Conv2d:
            # 1x1 convolution acting as a per-position linear projection.
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention to `x` of shape (b, c, h, w) and add the result to `x`."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        positions = height * width

        # Attention scores: scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
        query = query.reshape(batch, channels, positions).contiguous()
        query = query.permute(0, 2, 1).contiguous()  # (b, hw, c)
        key = key.reshape(batch, channels, positions).contiguous()  # (b, c, hw)
        scores = torch.bmm(query, key).contiguous()  # (b, hw, hw)
        scores = scores * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # Weighted sum of values: out[b, c, j] = sum_i value[b, c, i] * weights[b, i, j]
        value = value.reshape(batch, channels, positions).contiguous()
        weights = weights.permute(0, 2, 1).contiguous()  # (b, hw of key, hw of query)
        attended = torch.bmm(value, weights).contiguous()  # (b, c, hw of query)
        attended = attended.reshape(batch, channels, height, width).contiguous()

        return x + self.proj_out(attended)
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """
    Build the attention module selected by `attn_type`.
    Args:
        in_channels: Channel count of the feature map the module receives.
        attn_type: Which attention implementation to build.
        norm_type: Normalisation type forwarded to `AttnBlock`.
    Returns:
        An `AttnBlock` for VANILLA, or `torch.nn.Identity` for NONE.
    Raises:
        NotImplementedError: For LINEAR attention.
        ValueError: For any other value.
    """
    if attn_type == AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type == AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type == AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
class CausalityAxis(Enum):
    """Axis along which the convolution layers in this file apply causal padding."""

    # No causal axis: `CausalConv2d` pads both axes symmetrically.
    NONE = None
    # Causal along the last (width) axis: all width padding goes on the left.
    WIDTH = "width"
    # Causal along the height axis: all height padding goes on the top.
    HEIGHT = "height"
    # Padded like WIDTH by `CausalConv2d`, but `Downsample` pads one column on
    # the left instead of two and `Upsample` leaves its output uncropped.
    WIDTH_COMPATIBILITY = "width-compatibility"
class CausalConv2d(torch.nn.Module):
    """
    2D convolution with manual, optionally one-sided, zero padding.
    The wrapped `Conv2d` uses no padding of its own. Instead the input is padded
    by `(kernel - 1) * dilation` per axis before the convolution. Along the
    causal axis all of that padding is placed on the leading side (left for
    width, top for height), so an output position never sees later inputs along
    that axis. The non-causal axis is padded on both sides.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size, as an int or a (height, width) pair.
            stride: Convolution stride.
            dilation: Dilation, as an int or a (height, width) pair.
            groups: Number of groups for grouped convolution.
            bias: Whether the convolution has a bias term.
            causality_axis: Axis that receives one-sided padding.
        Raises:
            ValueError: If `causality_axis` is not a known `CausalityAxis`.
        """
        super().__init__()
        self.causality_axis = causality_axis

        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Total padding needed per axis to keep the size at stride 1.
        total_h = (kernel_size[0] - 1) * dilation[0]
        total_w = (kernel_size[1] - 1) * dilation[1]
        # Symmetric split; the trailing side gets the extra element when odd.
        both_sides_h = (total_h // 2, total_h - total_h // 2)
        both_sides_w = (total_w // 2, total_w - total_w // 2)

        # F.pad order for the last two dims: (left, right, top, bottom).
        axis = self.causality_axis
        if axis == CausalityAxis.NONE:
            self.padding = (*both_sides_w, *both_sides_h)
        elif axis == CausalityAxis.WIDTH or axis == CausalityAxis.WIDTH_COMPATIBILITY:
            self.padding = (total_w, 0, *both_sides_h)
        elif axis == CausalityAxis.HEIGHT:
            self.padding = (*both_sides_w, total_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # Padding is applied in forward(), so the convolution itself uses none.
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad `x` with the precomputed padding, then convolve."""
        return self.conv(F.pad(x, self.padding))
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Build either a `CausalConv2d` or a plain `torch.nn.Conv2d`.
    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolution kernel.
        stride: Convolution stride.
        padding: Padding for the plain convolution; `kernel_size // 2` per axis
            when None. Ignored when `causality_axis` is given, and otherwise
            forwarded to `Conv2d` unchanged.
        dilation: Dilation rate.
        groups: Number of groups for grouped convolution.
        bias: Whether to use a bias term.
        causality_axis: When not None (including `CausalityAxis.NONE`), a
            `CausalConv2d` configured for that axis is returned.
    Returns:
        A `CausalConv2d` when `causality_axis` is given, else a `torch.nn.Conv2d`.
    """
    if causality_axis is not None:
        # CausalConv2d computes and applies its own padding.
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)

    if padding is None:
        # Default to half the kernel size on each axis.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

    return torch.nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
# Negative slope passed to every leaky_relu call in ResBlock1 and ResBlock2.
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
    """
    Stack of three residual units built from 1D convolutions.
    Each unit applies leaky-ReLU, a dilated convolution, leaky-ReLU and an
    undilated convolution, then adds the result to its input. All convolutions
    keep the channel count and use "same" padding.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
        """
        Args:
            channels: Channel count of input and output.
            kernel_size: Kernel size shared by all convolutions.
            dilation: Dilation of the first convolution in each of the three units.
        """
        super().__init__()

        def same_conv(rate: int) -> torch.nn.Conv1d:
            # Stride-1, length-preserving convolution with the given dilation.
            return torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=rate, padding="same")

        dilated_rates = (dilation[0], dilation[1], dilation[2])
        self.convs1 = torch.nn.ModuleList([same_conv(rate) for rate in dilated_rates])
        self.convs2 = torch.nn.ModuleList([same_conv(1) for _ in dilated_rates])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three residual units in sequence."""
        activation = torch.nn.functional.leaky_relu
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2, strict=True):
            branch = dilated_conv(activation(x, LRELU_SLOPE))
            branch = plain_conv(activation(branch, LRELU_SLOPE))
            x = branch + x
        return x
class ResBlock2(torch.nn.Module):
    """
    Stack of two residual units built from 1D convolutions.
    Each unit applies leaky-ReLU followed by one dilated, "same"-padded
    convolution and adds the result to its input.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        """
        Args:
            channels: Channel count of input and output.
            kernel_size: Kernel size shared by both convolutions.
            dilation: Dilation of the first and second convolution.
        """
        super().__init__()
        dilated_rates = (dilation[0], dilation[1])
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=rate, padding="same")
                for rate in dilated_rates
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual units in sequence."""
        for dilated_conv in self.convs:
            branch = dilated_conv(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x
class ResnetBlock(torch.nn.Module):
    """
    Pre-activation residual block: two norm -> SiLU -> 3x3 conv stages with an
    optional embedding projection in between and a learned shortcut whenever
    the channel count changes.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Channel count of the input.
            out_channels: Channel count of the output; equals `in_channels` when None.
            conv_shortcut: Use a 3x3 (instead of 1x1) convolution on the shortcut
                when the channel count changes.
            dropout: Dropout probability applied before the second convolution.
            temb_channels: Width of the embedding passed as `temb`; the projection
                layer is only created when this is positive.
            norm_type: Normalisation layer type.
            causality_axis: Causal axis forwarded to every convolution.
        Raises:
            ValueError: If GroupNorm is combined with a causal axis.
        """
        super().__init__()
        self.causality_axis = causality_axis
        if norm_type == NormType.GROUP and self.causality_axis != CausalityAxis.NONE:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")

        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        def conv(source: int, target: int, kernel: int) -> torch.nn.Module:
            return make_conv2d(source, target, kernel_size=kernel, stride=1, causality_axis=causality_axis)

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = conv(in_channels, out_channels, 3)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = conv(out_channels, out_channels, 3)

        # A shortcut projection is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = conv(in_channels, out_channels, 3)
            else:
                self.nin_shortcut = conv(in_channels, out_channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input feature map.
            temb: Optional embedding; when given it is projected and added per
                channel after the first convolution.
        Returns:
            The (possibly projected) input plus the residual branch.
        """
        hidden = self.conv1(self.non_linearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the projected embedding over both spatial axes.
            hidden = hidden + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        hidden = self.non_linearity(self.norm2(hidden))
        hidden = self.conv2(self.dropout(hidden))

        if self.in_channels == self.out_channels:
            return x + hidden
        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(x) + hidden
class Downsample(torch.nn.Module):
    """
    Halve the spatial resolution, either with a stride-2 3x3 convolution
    (manually padded, optionally one-sided along the causal axis) or with
    2x2 average pooling.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        """
        Args:
            in_channels: Channel count of the input (kept by the convolution).
            with_conv: Use the strided convolution instead of average pooling.
            causality_axis: Axis that receives one-sided padding.
        Raises:
            ValueError: If a causal axis is requested together with `with_conv=False`.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        wants_causality = self.causality_axis != CausalityAxis.NONE
        if wants_causality and not self.with_conv:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Conv2d cannot pad asymmetrically, so padding is applied in forward().
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample `x`; raises ValueError for an unknown causality axis."""
        if not self.with_conv:
            # Pooling path: only reachable with causality_axis NONE (checked in __init__).
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Padding is given as (left, right, top, bottom).
        axis = self.causality_axis
        if axis == CausalityAxis.NONE:
            pad = (0, 1, 0, 1)
        elif axis == CausalityAxis.WIDTH:
            pad = (2, 0, 0, 1)
        elif axis == CausalityAxis.HEIGHT:
            pad = (0, 1, 2, 0)
        elif axis == CausalityAxis.WIDTH_COMPATIBILITY:
            pad = (1, 0, 0, 1)
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Build the encoder stages: residual blocks, optional attention, and a
    downsampling layer after every stage except the last.
    Args:
        ch: Base channel count.
        ch_mult: Channel multiplier per resolution level.
        num_resolutions: Number of resolution levels (stages) to build.
        num_res_blocks: Residual blocks per stage.
        resolution: Resolution at the first stage; halved after each downsample.
        temb_channels: Embedding width forwarded to every `ResnetBlock`.
        dropout: Dropout probability forwarded to every `ResnetBlock`.
        norm_type: Normalisation layer type.
        causality_axis: Causal axis for the convolutions and downsampling.
        attn_type: Attention implementation for stages that use attention.
        attn_resolutions: Resolutions at which every residual block gets an
            attention module.
        resamp_with_conv: Whether `Downsample` uses a strided convolution.
    Returns:
        The stage list (each stage exposes `block`, `attn` and, except for the
        last, `downsample`) and the channel count produced by the final stage.
    """
    stages = torch.nn.ModuleList()
    # Stage i consumes ch * input_mult[i] channels and emits ch * ch_mult[i].
    input_mult = (1, *tuple(ch_mult))
    current_resolution = resolution
    channels_in = ch
    last_level = num_resolutions - 1

    for level in range(num_resolutions):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()

        channels_in = ch * input_mult[level]
        channels_out = ch * ch_mult[level]
        with_attention = current_resolution in attn_resolutions

        for _ in range(num_res_blocks):
            stage.block.append(
                ResnetBlock(
                    in_channels=channels_in,
                    out_channels=channels_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            # Every block after the first keeps the stage's output width.
            channels_in = channels_out
            if with_attention:
                stage.attn.append(make_attn(channels_in, attn_type=attn_type, norm_type=norm_type))

        if level != last_level:
            stage.downsample = Downsample(channels_in, resamp_with_conv, causality_axis=causality_axis)
            current_resolution = current_resolution // 2

        stages.append(stage)

    return stages, channels_in
class Upsample(torch.nn.Module):
# Nearest-neighbour 2x upsampling, optionally followed by a 3x3 stride-1 convolution built by make_conv2d,
# with the first element along the causal axis dropped for the HEIGHT and WIDTH axes.
# NOTE(review): indentation was lost in this copy of the file, so it cannot be told from here whether the
# causal crop in forward() runs only when with_conv is True or unconditionally; confirm against upstream.
def __init__(
self,
in_channels: int,
with_conv: bool,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
) -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
# The convolution keeps the channel count; it only exists when with_conv is True.
if self.with_conv:
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# A scalar scale_factor doubles every spatial dimension of x.
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
# So the output elements rely on the following windows:
# 0: [-,-,0]
# 1: [-,0,0]
# 2: [0,0,1]
# 3: [0,1,1]
# 4: [1,1,2]
# 5: [1,2,2]
# Notice that the first and second elements in the output rely only on the first element in the input,
# while all other elements rely on two elements in the input.
# So we can drop the first element to undo the padding (rather than the last element).
# This is a no-op for non-causal convolutions.
match self.causality_axis:
case CausalityAxis.NONE:
pass # x remains unchanged
case CausalityAxis.HEIGHT:
# Height is dim 2 of the (b, c, h, w) tensor.
x = x[:, :, 1:, :]
case CausalityAxis.WIDTH:
# Width is dim 3 of the (b, c, h, w) tensor.
x = x[:, :, :, 1:]
case CausalityAxis.WIDTH_COMPATIBILITY:
pass # x remains unchanged
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Build the decoder stages: residual blocks, optional attention, and an
    upsampling layer on every stage except level 0.
    Stages are constructed from the deepest level upwards but stored so that
    index `i` of the returned list corresponds to resolution level `i`.
    Args:
        ch: Base channel count.
        ch_mult: Channel multiplier per resolution level.
        num_resolutions: Number of resolution levels (stages) to build.
        num_res_blocks: Each stage gets `num_res_blocks + 1` residual blocks.
        resolution: Full resolution; construction starts at
            `resolution // 2 ** (num_resolutions - 1)` and doubles per upsample.
        temb_channels: Embedding width forwarded to every `ResnetBlock`.
        dropout: Dropout probability forwarded to every `ResnetBlock`.
        norm_type: Normalisation layer type.
        causality_axis: Causal axis for the convolutions and upsampling.
        attn_type: Attention implementation for stages that use attention.
        attn_resolutions: Resolutions at which every residual block gets an
            attention module.
        resamp_with_conv: Whether `Upsample` applies a convolution.
        initial_block_channels: Channel count entering the deepest stage.
    Returns:
        The stage list (each stage exposes `block`, `attn` and, except for
        level 0, `upsample`) and the channel count leaving level 0.
    """
    stages = torch.nn.ModuleList()
    channels_in = initial_block_channels
    current_resolution = resolution // (2 ** (num_resolutions - 1))
    blocks_per_stage = num_res_blocks + 1

    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()

        channels_out = ch * ch_mult[level]
        with_attention = current_resolution in attn_resolutions

        for _ in range(blocks_per_stage):
            stage.block.append(
                ResnetBlock(
                    in_channels=channels_in,
                    out_channels=channels_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels_in = channels_out
            if with_attention:
                stage.attn.append(make_attn(channels_in, attn_type=attn_type, norm_type=norm_type))

        if level != 0:
            stage.upsample = Upsample(channels_in, resamp_with_conv, causality_axis=causality_axis)
            current_resolution *= 2

        # Prepend so that the list ends up ordered by level.
        stages.insert(0, stage)

    return stages, channels_in
class PerChannelStatistics(nn.Module):
    """
    Per-channel mean and standard deviation used to normalize and denormalize
    the latent representation.
    These statistics are computed over the entire dataset and stored in the
    model's checkpoint under the AudioVAE state_dict. The buffers are allocated
    uninitialised here and are expected to be filled when the checkpoint loads.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        """
        Args:
            latent_channels: Length of each statistics buffer.
        """
        super().__init__()
        # Buffer names are the state_dict keys and must stay as they are.
        for buffer_name in ("std-of-means", "mean-of-means"):
            self.register_buffer(buffer_name, torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized latents back: `x * std + mean`, with the buffers cast to match `x`."""
        scale = self.get_buffer("std-of-means").to(x)
        offset = self.get_buffer("mean-of-means").to(x)
        return (x * scale) + offset

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize latents: `(x - mean) / std`, with the buffers cast to match `x`."""
        offset = self.get_buffer("mean-of-means").to(x)
        scale = self.get_buffer("std-of-means").to(x)
        return (x - offset) / scale
# NOTE(review): not referenced in the visible part of this file; presumably the number of mel frames per
# latent frame (AudioPatchifier's audio_latent_downsample_factor also defaults to 4). Confirm at the use site.
LATENT_DOWNSAMPLE_FACTOR = 4
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """
    Build the middle block: ResNet block -> attention (or identity) -> ResNet block.
    Args:
        channels: Channel count kept throughout the block.
        temb_channels: Embedding width forwarded to both `ResnetBlock`s.
        dropout: Dropout probability forwarded to both `ResnetBlock`s.
        norm_type: Normalisation layer type.
        causality_axis: Causal axis for the convolutions.
        attn_type: Attention implementation used when `add_attention` is True.
        add_attention: When False, `attn_1` is a `torch.nn.Identity`.
    Returns:
        A module exposing `block_1`, `attn_1` and `block_2`.
    """

    def resnet() -> torch.nn.Module:
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    mid = torch.nn.Module()
    mid.block_1 = resnet()
    if add_attention:
        mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        mid.attn_1 = torch.nn.Identity()
    mid.block_2 = resnet()
    return mid
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """
    Apply `mid.block_1`, `mid.attn_1` and `mid.block_2` in that order.
    Both blocks are called with `temb=None`.
    """
    hidden = mid.block_1(features, temb=None)
    hidden = mid.attn_1(hidden)
    hidden = mid.block_2(hidden, temb=None)
    return hidden
class LTX2AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.

    Pipeline: input convolution -> downsampling path (residual blocks with optional
    attention) -> middle block -> normalization, SiLU and output convolution. The latent
    means are then normalized with per-channel statistics before being returned.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int = 128,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Set[int] = set(),
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int = 2,
        resolution: int = 256,
        z_channels: int = 8,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = False,
        norm_type: NormType = NormType.PIXEL,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.

        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels produced by the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Resolutions at which attention is requested; forwarded to
                ``build_downsampling_path``.
            dropout: Dropout probability forwarded to the downsampling path and the mid block.
            resamp_with_conv: Forwarded to ``build_downsampling_path``.
            in_channels: Number of channels in the input spectrogram tensor.
            resolution: Input resolution, forwarded to ``build_downsampling_path``.
            z_channels: Number of channels in the latent representation.
            double_z: If True, the output convolution emits ``2 * z_channels`` channels
                (means followed by a second half that ``forward`` discards).
            attn_type: Type of attention mechanism to use in attention blocks.
            mid_block_add_attention: If True, add an attention block in the middle block.
            norm_type: Normalization layer type to use within the network.
            causality_axis: Axis along which the convolutions are made causal.
            sample_rate: Audio sample rate in Hz; stored and forwarded to the patchifier.
            mel_hop_length: Mel hop length; stored and forwarded to the patchifier.
            n_fft: FFT size; only stored on the module.
            is_causal: Stored and forwarded to the patchifier.
            mel_bins: Number of mel-frequency bins; only stored on the module.
            **_ignore_kwargs: Extra configuration entries, accepted and ignored.
        """
        super().__init__()

        # NOTE(review): the statistics are sized with `ch`, not with a latent width derived
        # from z_channels/mel_bins; presumably these coincide for shipped checkpoints - confirm.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding: every block is called with temb=None
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )
        self.non_linearity = torch.nn.SiLU()

        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode audio spectrogram into latent representations.

        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)

        Returns:
            Normalized latent means of shape (batch, channels, frames, mel_bins)
        """
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)
        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        """Run every resolution level; all levels but the last end with a downsample."""
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                # An empty attention container means this level has no attention.
                if stage.attn:
                    h = stage.attn[block_idx](h)
            if level != self.num_resolutions - 1:
                h = stage.downsample(h)
        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        """Apply output normalization, SiLU and the output convolution."""
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.

        When the encoder is configured with ``double_z=True``, the final convolution
        produces twice the number of latent channels. Only the first half along the
        channel dimension (the "mean" component) is kept; the second half is not
        part of the returned tensor.

        When ``double_z=False`` the convolution output already holds only the means, so
        it is used as-is. (Chunking it unconditionally would drop half of the latent
        channels.)

        Args:
            latent_output: Output of the final convolution, (batch, channels, frames, mel_bins).

        Returns:
            Normalized means with the same layout as the selected mean tensor.
        """
        if self.double_z:
            means = torch.chunk(latent_output, 2, dim=1)[0]
        else:
            means = latent_output

        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        # Statistics are applied in the patchified layout, then the layout is restored.
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)
class LTX2AudioDecoder(torch.nn.Module):
    """
    Decoder that reconstructs audio spectrograms from latent features.

    Pipeline: latent de-normalization -> input convolution -> middle block ->
    upsampling path -> normalization, SiLU and output convolution -> crop/pad to the
    expected output shape.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int = 128,
        out_ch: int = 2,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Set[int] = set(),
        resolution: int = 256,
        z_channels: int = 8,
        norm_type: NormType = NormType.PIXEL,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
    ) -> None:
        """
        Initialize the Decoder.

        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels.
            out_ch: Number of channels of the reconstructed spectrogram.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks per level, forwarded to the upsampling path.
            attn_resolutions: Resolutions at which attention is requested.
            resolution: Output resolution, forwarded to the upsampling path.
            z_channels: Number of channels of the incoming latent.
            norm_type: Normalization layer type to use within the network.
            causality_axis: Axis along which the convolutions are made causal.
            dropout: Dropout probability forwarded to the mid block and upsampling path.
            mid_block_add_attention: If True, add an attention block in the middle block.
            sample_rate: Stored and forwarded to the patchifier.
            mel_hop_length: Stored and forwarded to the patchifier.
            is_causal: Stored and forwarded to the patchifier.
            mel_bins: Number of mel bins of the output; when None the latent's own
                mel dimension is used as the target.
        """
        super().__init__()

        # Behavioural settings that are fixed here instead of coming from the checkpoint.
        self.attn_type = AttentionType.VANILLA
        use_conv_resampling = True

        # Statistics used to undo the encoder-side latent normalization.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)

        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding: every block is called with temb=None
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis

        # The decoder starts at the widest channel count / coarsest resolution.
        widest_channels = ch * ch_mult[-1]
        coarsest_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, coarsest_resolution, coarsest_resolution)

        self.conv_in = make_conv2d(
            z_channels,
            widest_channels,
            kernel_size=3,
            stride=1,
            causality_axis=causality_axis,
        )
        self.non_linearity = torch.nn.SiLU()

        self.mid = build_mid_block(
            channels=widest_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.up, narrowest_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=use_conv_resampling,
            initial_block_channels=widest_channels,
        )

        self.norm_out = build_normalization_layer(narrowest_channels, normtype=norm_type)
        self.conv_out = make_conv2d(
            narrowest_channels,
            out_ch,
            kernel_size=3,
            stride=1,
            causality_axis=causality_axis,
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.

        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)

        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        latents, target_shape = self._denormalize_latents(sample)
        hidden = self.conv_in(latents)
        hidden = run_mid_block(self.mid, hidden)
        hidden = self._run_upsampling_path(hidden)
        hidden = self._finalize_output(hidden)
        return self._adjust_output_shape(hidden, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        """
        Undo the per-channel latent normalization and derive the expected output shape.

        Returns:
            The de-normalized latents (same layout as ``sample``) and the target shape of
            the decoded spectrogram.
        """
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        # Statistics are applied in the patchified layout, then the layout is restored.
        patched = self.patchifier.patchify(sample)
        restored = self.per_channel_statistics.un_normalize(patched)
        sample = self.patchifier.unpatchify(restored, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            # Causal configurations yield (factor - 1) fewer frames, but never fewer than one.
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        if self.mel_bins is not None:
            target_mel_bins = self.mel_bins
        else:
            target_mel_bins = latent_shape.mel_bins

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=target_mel_bins,
        )
        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Crop and/or zero-pad the decoded spectrogram so it matches ``target_shape``.

        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)

        Returns:
            Tensor whose channel, time and frequency sizes do not exceed the target and
            whose time/frequency sizes equal it.
        """
        _, _, have_time, have_freq = decoded_output.shape
        want_channels = target_shape.channels
        want_time = target_shape.frames
        want_freq = target_shape.mel_bins

        # Crop first so nothing exceeds the target size.
        result = decoded_output[
            :, :want_channels, : min(have_time, want_time), : min(have_freq, want_freq)
        ]

        # Then pad whatever is still missing at the end of the time/frequency axes.
        missing_time = want_time - result.shape[2]
        missing_freq = want_freq - result.shape[3]
        if missing_time > 0 or missing_freq > 0:
            # F.pad order is last dimension first: (freq_left, freq_right, time_top, time_bottom).
            result = F.pad(result, (0, max(missing_freq, 0), 0, max(missing_time, 0)))

        # Final safety crop to the exact target size.
        return result[:, :want_channels, :want_time, :want_freq]

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        """Run the levels from the highest index down to 0; level 0 is never upsampled."""
        for level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                # An empty attention container means this level has no attention.
                if stage.attn:
                    h = stage.attn[block_idx](h)
            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)
        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        """Apply normalization, SiLU and the output convolution (unless ``give_pre_end``)."""
        if self.give_pre_end:
            return h
        h = self.conv_out(self.non_linearity(self.norm_out(h)))
        if self.tanh_out:
            return torch.tanh(h)
        return h
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return half of ``dilation * (kernel_size - 1)``, truncated to an int."""
    dilated_span = dilation * (kernel_size - 1)
    return int(dilated_span / 2)
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor) -> torch.Tensor:
    """Normalized sinc ``sin(pi*x) / (pi*x)``, with the value at ``x == 0`` defined as 1."""
    return torch.where(
        x == 0,
        torch.tensor(1.0, device=x.device, dtype=x.dtype),
        torch.sin(math.pi * x) / math.pi / x,
    )


def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    """
    Build a Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff: Cutoff frequency. ``0`` yields an all-zero filter.
        half_width: Transition half-width; it determines the Kaiser window ``beta``.
        kernel_size: Number of filter taps.

    Returns:
        Floating-point tensor of shape ``(1, 1, kernel_size)``. For a non-zero cutoff the
        taps are normalized to sum to 1.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser window design: derive beta from the estimated attenuation.
    delta_f = 4 * half_width
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amplitude > 50.0:
        beta = 0.1102 * (amplitude - 8.7)
    elif amplitude >= 21.0:
        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    if cutoff == 0:
        # Nothing to normalize: dividing by the zero sum would produce NaNs (or fail on an
        # integer tensor for odd kernel sizes), so return floating-point zeros directly.
        filter_ = torch.zeros_like(window)
    else:
        # Tap positions are centred on zero: half-integer offsets for even kernels,
        # integer offsets for odd kernels.
        if even:
            time = torch.arange(-half_size, half_size) + 0.5
        else:
            time = torch.arange(kernel_size) - half_size
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
        filter_ /= filter_.sum()
    return filter_.view(1, 1, kernel_size)
class LowPassFilter1d(nn.Module):
    """Depthwise 1D low-pass filter built from a Kaiser-windowed sinc kernel."""

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        """
        Args:
            cutoff: Filter cutoff; must lie in ``[0, 0.5]``.
            half_width: Transition half-width passed to the filter design.
            stride: Stride of the filtering convolution.
            padding: Whether ``forward`` pads the input before filtering.
            padding_mode: Mode handed to ``F.pad``.
            kernel_size: Number of filter taps.

        Raises:
            ValueError: If ``cutoff`` is negative or greater than 0.5.
        """
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        half = kernel_size // 2
        is_even = kernel_size % 2 == 0

        self.kernel_size = kernel_size
        self.even = is_even
        # Even kernels get one sample less on the left than on the right.
        self.pad_left = half - 1 if is_even else half
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter a 3D input; every channel is convolved with the same kernel."""
        _, channels, _ = x.shape
        signal = x
        if self.padding:
            signal = F.pad(signal, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # One copy of the kernel per channel, applied as a grouped (depthwise) convolution.
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(signal, kernel, stride=self.stride, groups=channels)
class UpSample1d(nn.Module):
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
persistent: bool = True,
window_type: str = "kaiser",
) -> None:
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
sinc_filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size,
)
self.register_buffer("filter", sinc_filter, persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
_, n_channels, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
return x[..., self.pad_left : -self.pad_right]
class DownSample1d(nn.Module):
    """Integer-ratio 1D downsampler: a strided low-pass filter."""

    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        """
        Args:
            ratio: Downsampling factor, used as the stride of the low-pass filter.
            kernel_size: Filter length; derived from ``ratio`` when None.
        """
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            self.kernel_size = int(6 * ratio // 2) * 2
        else:
            self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Low-pass filter and decimate ``x`` in a single strided convolution."""
        filtered = self.lowpass(x)
        return filtered
class Activation1d(nn.Module):
    """Wrap an activation between an upsampler and a downsampler."""

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        """
        Args:
            activation: Module applied to the upsampled signal.
            up_ratio: Ratio of the upsampler.
            down_ratio: Ratio of the downsampler.
            up_kernel_size: Filter length of the upsampler.
            down_kernel_size: Filter length of the downsampler.
        """
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(ratio=up_ratio, kernel_size=up_kernel_size)
        self.downsample = DownSample1d(ratio=down_ratio, kernel_size=down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample, apply the activation, then downsample."""
        return self.downsample(self.act(self.upsample(x)))
class Snake(nn.Module):
    """Snake activation: ``x + sin^2(alpha * x) / alpha`` with one ``alpha`` per channel."""

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        """
        Args:
            in_features: Number of channels (length of the ``alpha`` vector).
            alpha: Initial value of ``alpha``; only used when ``alpha_logscale`` is False.
            alpha_trainable: Whether ``alpha`` requires gradients.
            alpha_logscale: If True, the parameter stores ``log(alpha)`` and starts at zero.
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            initial_alpha = torch.zeros(in_features)
        else:
            initial_alpha = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial_alpha)
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9  # guards the division by alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation; ``alpha`` is broadcast as (1, channels, 1)."""
        alpha = self.alpha[None, :, None]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        inverse_scale = 1.0 / (alpha + self.eps)
        periodic = torch.sin(x * alpha).pow(2)
        return x + inverse_scale * periodic
class SnakeBeta(nn.Module):
    """Snake activation with separate scales: ``x + sin^2(alpha * x) / beta`` per channel."""

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        """
        Args:
            in_features: Number of channels (length of the ``alpha`` and ``beta`` vectors).
            alpha: Initial value for both ``alpha`` and ``beta``; only used when
                ``alpha_logscale`` is False.
            alpha_trainable: Whether ``alpha`` and ``beta`` require gradients.
            alpha_logscale: If True, both parameters are stored in log scale and start at zero.
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale

        def initial_value() -> torch.Tensor:
            # alpha and beta start from the same values but are independent tensors.
            if alpha_logscale:
                return torch.zeros(in_features)
            return torch.ones(in_features) * alpha

        self.alpha = nn.Parameter(initial_value())
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(initial_value())
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9  # guards the division by beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation; both parameters are broadcast as (1, channels, 1)."""
        alpha = self.alpha[None, :, None]
        beta = self.beta[None, :, None]
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        inverse_scale = 1.0 / (beta + self.eps)
        periodic = torch.sin(x * alpha).pow(2)
        return x + inverse_scale * periodic
class AMPBlock1(nn.Module):
    """
    Residual block made of three (activation, dilated conv, activation, conv) stages.

    Every activation is wrapped in ``Activation1d``; each stage's result is added back
    onto its input.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        """
        Args:
            channels: Channel count of every convolution (input and output).
            kernel_size: Kernel size of every convolution.
            dilation: Dilations of the first convolution of each of the three stages.
            activation: ``"snakebeta"`` selects SnakeBeta; any other value selects Snake.
        """
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake

        def make_conv(rate: int) -> nn.Conv1d:
            # Stride-1 convolution padded via get_padding for the given dilation.
            return nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=rate,
                padding=get_padding(kernel_size, rate),
            )

        # First convolution of each stage uses the configured dilation ...
        self.convs1 = nn.ModuleList([make_conv(dilation[stage]) for stage in range(3)])
        # ... the second one is always undilated.
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])

        self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in self.convs1])
        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in self.convs2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three residual stages sequentially."""
        stages = zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True)
        for dilated_conv, plain_conv, first_act, second_act in stages:
            branch = dilated_conv(first_act(x))
            branch = plain_conv(second_act(branch))
            x = x + branch
        return x
class LTX2Vocoder(torch.nn.Module):
"""
LTX2Vocoder model for synthesizing audio from Mel spectrograms.
Args:
resblock_kernel_sizes: List of kernel sizes for the residual blocks.
This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
upsample_rates: List of upsampling rates.
This value is read from the checkpoint at `config.vocoder.upsample_rates`.
upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
resblock_dilation_sizes: List of dilation sizes for the residual blocks.
This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
upsample_initial_channel: Initial number of channels for the upsampling layers.
This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
resblock: Type of residual block to use ("1", "2", or "AMP1").
This value is read from the checkpoint at `config.vocoder.resblock`.
output_sampling_rate: Waveform sample rate.
This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
apply_final_activation: Whether to apply the final tanh/clamp activation.
use_bias_at_final: Whether to use bias in the final conv layer.
"""
def __init__( # noqa: PLR0913
self,
resblock_kernel_sizes: List[int] | None = [3, 7, 11],
upsample_rates: List[int] | None = [6, 5, 2, 2, 2],
upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4],
resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel: int = 1024,
resblock: str = "1",
output_sampling_rate: int = 24000,
activation: str = "snake",
use_tanh_at_final: bool = True,
apply_final_activation: bool = True,
use_bias_at_final: bool = True,
) -> None:
super().__init__()
# Mutable default values are not supported as default arguments.
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sampling_rate = output_sampling_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.use_tanh_at_final = use_tanh_at_final
self.apply_final_activation = apply_final_activation
self.is_amp = resblock == "AMP1"
# All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
# bins each), 2 output channels.
self.conv_pre = nn.Conv1d(
in_channels=128,
out_channels=upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
)
resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
self.ups = nn.ModuleList(
nn.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
kernel_size,
stride,
padding=(kernel_size - stride) // 2,
)
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
)
final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
self.resblocks = nn.ModuleList()
for i in range(len(upsample_rates)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
if self.is_amp:
self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
if self.is_amp:
self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
else:
self.act_post = nn.LeakyReLU()
# All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
self.conv_post = nn.Conv1d(
in_channels=final_channels,
out_channels=2,
kernel_size=7,
stride=1,
padding=3,
bias=use_bias_at_final,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the vocoder.
Args:
x: Input Mel spectrogram tensor. Can be either:
- 3D: (batch_size, time, mel_bins) for mono
- 4D: (batch_size, 2, time, mel_bins) for stereo
Returns:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
if x.dim() == 4: # stereo
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = einops.rearrange(x, "b s c t -> b (s c) t")
x = self.conv_pre(x)
for i in range(self.num_upsamples):
if not self.is_amp:
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
start = i * self.num_kernels
end = start + self.num_kernels
# Evaluate all resblocks with the same input tensor so they can run
# independently (and thus in parallel on accelerator hardware) before
# aggregating their outputs via mean.
block_outputs = torch.stack(
[self.resblocks[idx](x) for idx in range(start, end)],
dim=0,
)
x = block_outputs.mean(dim=0)
x = self.act_post(x)
x = self.conv_post(x)
if self.apply_final_activation:
x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
return x
class _STFTFn(nn.Module):
    """STFT implemented as a strided 1D convolution with precomputed basis filters.

    The bases are registered as zero-filled buffers and are expected to be overwritten
    when the checkpoint is loaded. ``forward`` treats the first half of the basis rows
    as the real part and the second half as the imaginary part. Using the exact bases
    from training keeps the mel values fed to the BWE generator consistent with what
    it was trained on.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
        """
        Args:
            filter_length: Length of each basis filter; ``filter_length // 2 + 1``
                frequency bins are produced.
            hop_length: Stride between consecutive frames.
            win_length: Window length; together with ``hop_length`` it sets the causal padding.
        """
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        basis_rows = 2 * (filter_length // 2 + 1)  # real rows followed by imaginary rows
        self.register_buffer("forward_basis", torch.zeros(basis_rows, 1, filter_length))
        self.register_buffer("inverse_basis", torch.zeros(basis_rows, 1, filter_length))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrogram from a batch of waveforms.

        The input is padded on the left only, by ``win_length - hop_length`` samples
        (never negative), so no frame covers samples beyond the end of the input.

        Args:
            y: Waveform tensor of shape (B, T); a tensor that is not 2D is used as-is.

        Returns:
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
        """
        waveform = y.unsqueeze(1) if y.dim() == 2 else y  # (B, 1, T)
        causal_pad = max(0, self.win_length - self.hop_length)
        waveform = F.pad(waveform, (causal_pad, 0))

        spectrum = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
        half = spectrum.shape[1] // 2
        real = spectrum[:, :half]
        imag = spectrum[:, half:]

        magnitude = torch.sqrt(real**2 + imag**2)
        # atan2 is evaluated in float32 and cast back to the working dtype.
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase
class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Runs the causal STFT (_STFTFn) on the input waveform and projects the linear
    magnitude spectrum onto the mel filterbank. The state dict layout matches the
    'mel_stft.*' keys stored in the checkpoint
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
    ) -> None:
        """
        Args:
            filter_length: STFT filter length; yields ``filter_length // 2 + 1`` frequency bins.
            hop_length: STFT hop length.
            win_length: STFT window length.
            n_mel_channels: Number of rows of the mel filterbank.
        """
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        # Zero placeholder of shape [n_mels, n_freqs]; load_state_dict overwrites it with
        # the checkpoint's filterbank (vocoder.mel_stft.mel_basis).
        placeholder = torch.zeros(n_mel_channels, filter_length // 2 + 1)
        self.register_buffer("mel_basis", placeholder)

    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)

        filterbank = self.mel_basis.to(magnitude.dtype)
        mel = torch.matmul(filterbank, magnitude)
        # Values are floored at 1e-5 before the log so it stays finite.
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy
class LTX2VocoderWithBWE(nn.Module):
    """LTX2Vocoder with bandwidth extension (BWE) upsampling.

    A mel-to-wav vocoder is followed by a BWE stage: a mel spectrogram is computed from
    the vocoder output, a second generator turns it into a residual, and that residual
    is added to a sinc-resampled copy of the vocoder output (skip connection).
    """

    def __init__(
        self,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        """
        Args:
            input_sampling_rate: Sample rate assigned to the first-stage vocoder output.
            output_sampling_rate: Sample rate assigned to the BWE generator output.
            hop_length: Hop length of the mel analysis applied to the vocoder output.
        """
        super().__init__()

        # Settings shared by both generators.
        shared_generator_config = {
            "resblock_kernel_sizes": [3, 7, 11],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "resblock": "AMP1",
            "activation": "snakebeta",
            "use_tanh_at_final": False,
            "use_bias_at_final": False,
        }

        self.vocoder = LTX2Vocoder(
            upsample_rates=[5, 2, 2, 2, 2, 2],
            upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
            upsample_initial_channel=1536,
            apply_final_activation=True,
            output_sampling_rate=input_sampling_rate,
            **shared_generator_config,
        )
        self.bwe_generator = LTX2Vocoder(
            upsample_rates=[6, 5, 2, 2, 2],
            upsample_kernel_sizes=[12, 11, 4, 4, 4],
            upsample_initial_channel=512,
            apply_final_activation=False,
            output_sampling_rate=output_sampling_rate,
            **shared_generator_config,
        )

        self.mel_stft = MelSTFT(
            filter_length=512,
            hop_length=hop_length,
            win_length=512,
            n_mel_channels=64,
        )
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length

        # Compute the resampler on CPU so the sinc filter is materialized even when
        # the model is constructed on meta device (SingleGPUModelBuilder pattern).
        # The filter is not stored in the checkpoint (persistent=False).
        upsampling_ratio = output_sampling_rate // input_sampling_rate
        with torch.device("cpu"):
            self.resampler = UpSample1d(ratio=upsampling_ratio, persistent=False, window_type="hann")

    @property
    def conv_pre(self) -> nn.Conv1d:
        """Input convolution of the first-stage vocoder."""
        return self.vocoder.conv_pre

    @property
    def conv_post(self) -> nn.Conv1d:
        """Output convolution of the first-stage vocoder."""
        return self.vocoder.conv_post

    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute log-mel spectrogram from waveform using causal STFT bases.

        Args:
            audio: Waveform tensor of shape (B, C, T).

        Returns:
            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        # Channels are folded into the batch axis so each one is analysed independently.
        per_channel = audio.reshape(batch * n_channels, -1)  # (B*C, T)
        log_mel, _, _, _ = self.mel_stft.mel_spectrogram(per_channel)  # (B*C, n_mels, T_frames)
        n_mels, n_frames = log_mel.shape[1], log_mel.shape[2]
        return log_mel.reshape(batch, n_channels, n_mels, n_frames)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Run the full vocoder + BWE forward pass.

        Args:
            mel_spec: Mel spectrogram in the same format as LTX2Vocoder.forward expects.

        Returns:
            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
        """
        low_rate_audio = self.vocoder(mel_spec)
        _, _, length_low_rate = low_rate_audio.shape
        # Length of the final waveform, computed before any padding is added.
        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate

        # Pad to multiple of hop_length for exact mel frame count
        leftover = length_low_rate % self.hop_length
        if leftover != 0:
            low_rate_audio = F.pad(low_rate_audio, (0, self.hop_length - leftover))

        # (B, C, n_mels, T_frames) -> (B, C, T_frames, mel_bins), the layout LTX2Vocoder.forward takes.
        mel_for_bwe = self._compute_mel(low_rate_audio).transpose(2, 3)

        residual = self.bwe_generator(mel_for_bwe)
        skip = self.resampler(low_rate_audio)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        # Clip, then drop the samples that only exist because of the padding above.
        combined = torch.clamp(residual + skip, -1, 1)
        return combined[..., :output_length]
|