diff --git a/LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log b/LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log new file mode 100644 index 0000000000000000000000000000000000000000..51f97cab1597aaa8df4cece1047d8a78d4ee57d0 --- /dev/null +++ b/LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log @@ -0,0 +1,700 @@ +t-20260515055614-rg8sr-worker-0:10254:10254 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10254:10254 [0] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10254:10254 [0] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10254:10254 [0] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10254:10254 [0] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10257:10257 [3] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10257:10257 [3] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10257:10257 [3] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10257:10257 [3] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10257:10257 [3] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10260:10260 [6] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10260:10260 [6] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10260:10260 [6] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10260:10260 [6] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10260:10260 [6] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10261:10261 [7] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10261:10261 [7] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10261:10261 [7] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10261:10261 [7] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10261:10261 [7] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10259:10259 [5] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10259:10259 [5] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10259:10259 [5] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10259:10259 [5] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10259:10259 [5] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10255:10255 [1] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10255:10255 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10255:10255 [1] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10255:10255 [1] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10255:10255 [1] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10258:10258 [4] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10258:10258 [4] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10256:10256 [2] NCCL INFO cudaDriverVersion 12080 +t-20260515055614-rg8sr-worker-0:10256:10256 [2] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10258:10258 [4] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10258:10258 [4] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10256:10256 [2] NCCL INFO Bootstrap: Using eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10256:10256 [2] NCCL INFO NCCL version 2.25.1+cuda12.8 +t-20260515055614-rg8sr-worker-0:10258:10258 [4] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10256:10256 [2] NCCL INFO Comm config Blocking set to 1 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9) +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9) +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO P2P plugin v9 IBext_v9 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth1 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1. +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE [1]mlx5_4:1/RoCE [2]mlx5_5:1/RoCE [3]mlx5_6:1/RoCE [4]mlx5_7:1/RoCE [5]mlx5_8:1/RoCE [6]mlx5_9:1/RoCE [7]mlx5_10:1/RoCE [RO]; OOB eth1:10.82.40.49<0> +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Using network IBext_v9 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO ncclCommInitRankConfig comm 0xb26c3d0 rank 0 nranks 8 cudaDev 0 nvmlDev 0 busId 65040 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO ncclCommInitRankConfig comm 0xb006800 rank 6 nranks 8 cudaDev 6 nvmlDev 6 busId 73020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO ncclCommInitRankConfig comm 0xacaf050 rank 3 nranks 8 cudaDev 3 nvmlDev 3 busId 6b020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO ncclCommInitRankConfig comm 0xa8667e0 rank 7 nranks 8 cudaDev 7 nvmlDev 7 busId 75020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO ncclCommInitRankConfig comm 0xa26a890 rank 5 nranks 8 cudaDev 5 nvmlDev 5 busId 71020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO ncclCommInitRankConfig comm 0xabf2300 rank 1 nranks 8 cudaDev 1 nvmlDev 1 busId 67020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO ncclCommInitRankConfig comm 0xa2ffc70 rank 2 nranks 8 cudaDev 2 nvmlDev 2 busId 69020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO ncclCommInitRankConfig comm 0xb2ac220 rank 4 nranks 8 cudaDev 4 nvmlDev 4 busId 6f020 commId 0x721971c0e3ca11b3 - Init START +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO RAS client listening socket at ::1<28028> +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Bootstrap timings total 0.204779 (create 0.000022, send 0.000068, recv 0.000104, ring 0.204170, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Bootstrap timings total 0.240300 (create 0.000021, send 0.000067, recv 0.239751, ring 0.000173, delay 0.000000) +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Bootstrap timings total 0.137077 (create 0.000020, send 0.000071, recv 0.000043, ring 0.000141, delay 0.000000) +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Bootstrap timings total 0.000632 (create 0.000018, send 0.000064, recv 0.000089, ring 0.000200, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Bootstrap timings total 0.002598 (create 0.000020, send 0.000069, recv 0.000123, ring 0.002079, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Bootstrap timings total 0.066493 (create 0.000022, send 0.000062, recv 0.064005, ring 0.002079, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Bootstrap timings total 0.251243 (create 0.000025, send 0.000083, recv 0.046508, ring 0.136663, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Bootstrap timings total 0.290563 (create 0.000030, send 0.000068, recv 0.224176, ring 0.065935, delay 0.000001) +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO MNNVL busId 0x6f020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO MNNVL busId 0x67020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO MNNVL busId 0x75020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO MNNVL busId 0x6b020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO MNNVL busId 0x71020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO MNNVL busId 0x65040 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO MNNVL busId 0x69020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO MNNVL busId 0x73020 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NCCL_TOPO_FILE set by environment to /var/run/nvidia-topologyd/virtualTopology.xml +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Setting affinity for GPU 4 to 0fffff,ffffffff,ffffffff,fc000000,00000000,00000000 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO NVLS multicast support is available on dev 4 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Setting affinity for GPU 3 to 03ffffff,ffffffff,ffffffff +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Setting affinity for GPU 2 to 03ffffff,ffffffff,ffffffff +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO NVLS multicast support is available on dev 3 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Setting affinity for GPU 1 to 03ffffff,ffffffff,ffffffff +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO NVLS multicast support is available on dev 1 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Setting affinity for GPU 0 to 03ffffff,ffffffff,ffffffff +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO NVLS multicast support is available on dev 0 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO NVLS multicast support is available on dev 2 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Setting affinity for GPU 7 to 0fffff,ffffffff,ffffffff,fc000000,00000000,00000000 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO NVLS multicast support is available on dev 7 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Setting affinity for GPU 6 to 0fffff,ffffffff,ffffffff,fc000000,00000000,00000000 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO NVLS multicast support is available on dev 6 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Setting affinity for GPU 5 to 0fffff,ffffffff,ffffffff,fc000000,00000000,00000000 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO NVLS multicast support is available on dev 5 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO comm 0xb006800 rank 6 nRanks 8 nNodes 1 localRanks 8 localRank 6 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO comm 0xa26a890 rank 5 nRanks 8 nNodes 1 localRanks 8 localRank 5 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO comm 0xb2ac220 rank 4 nRanks 8 nNodes 1 localRanks 8 localRank 4 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO comm 0xa2ffc70 rank 2 nRanks 8 nNodes 1 localRanks 8 localRank 2 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO comm 0xacaf050 rank 3 nRanks 8 nNodes 1 localRanks 8 localRank 3 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO comm 0xabf2300 rank 1 nRanks 8 nNodes 1 localRanks 8 localRank 1 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO comm 0xb26c3d0 rank 0 nRanks 8 nNodes 1 localRanks 8 localRank 0 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO comm 0xa8667e0 rank 7 nRanks 8 nNodes 1 localRanks 8 localRank 7 MNNVL 0 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 00/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 01/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Trees [0] 7/-1/-1->6->5 [1] 7/-1/-1->6->5 [2] 7/-1/-1->6->5 [3] 7/-1/-1->6->5 [4] 7/-1/-1->6->5 [5] 7/-1/-1->6->5 [6] 7/-1/-1->6->5 [7] 7/-1/-1->6->5 [8] 7/-1/-1->6->5 [9] 7/-1/-1->6->5 [10] 7/-1/-1->6->5 [11] 7/-1/-1->6->5 [12] 7/-1/-1->6->5 [13] 7/-1/-1->6->5 [14] 7/-1/-1->6->5 [15] 7/-1/-1->6->5 [16] 7/-1/-1->6->5 [17] 7/-1/-1->6->5 [18] 7/-1/-1->6->5 [19] 7/-1/-1->6->5 [20] 7/-1/-1->6->5 [21] 7/-1/-1->6->5 [22] 7/-1/-1->6->5 [23] 7/-1/-1->6->5 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Trees [0] 5/-1/-1->4->3 [1] 5/-1/-1->4->3 [2] 5/-1/-1->4->3 [3] 5/-1/-1->4->3 [4] 5/-1/-1->4->3 [5] 5/-1/-1->4->3 [6] 5/-1/-1->4->3 [7] 5/-1/-1->4->3 [8] 5/-1/-1->4->3 [9] 5/-1/-1->4->3 [10] 5/-1/-1->4->3 [11] 5/-1/-1->4->3 [12] 5/-1/-1->4->3 [13] 5/-1/-1->4->3 [14] 5/-1/-1->4->3 [15] 5/-1/-1->4->3 [16] 5/-1/-1->4->3 [17] 5/-1/-1->4->3 [18] 5/-1/-1->4->3 [19] 5/-1/-1->4->3 [20] 5/-1/-1->4->3 [21] 5/-1/-1->4->3 [22] 5/-1/-1->4->3 [23] 5/-1/-1->4->3 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 02/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Trees [0] 4/-1/-1->3->2 [1] 4/-1/-1->3->2 [2] 4/-1/-1->3->2 [3] 4/-1/-1->3->2 [4] 4/-1/-1->3->2 [5] 4/-1/-1->3->2 [6] 4/-1/-1->3->2 [7] 4/-1/-1->3->2 [8] 4/-1/-1->3->2 [9] 4/-1/-1->3->2 [10] 4/-1/-1->3->2 [11] 4/-1/-1->3->2 [12] 4/-1/-1->3->2 [13] 4/-1/-1->3->2 [14] 4/-1/-1->3->2 [15] 4/-1/-1->3->2 [16] 4/-1/-1->3->2 [17] 4/-1/-1->3->2 [18] 4/-1/-1->3->2 [19] 4/-1/-1->3->2 [20] 4/-1/-1->3->2 [21] 4/-1/-1->3->2 [22] 4/-1/-1->3->2 [23] 4/-1/-1->3->2 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0 [2] 2/-1/-1->1->0 [3] 2/-1/-1->1->0 [4] 2/-1/-1->1->0 [5] 2/-1/-1->1->0 [6] 2/-1/-1->1->0 [7] 2/-1/-1->1->0 [8] 2/-1/-1->1->0 [9] 2/-1/-1->1->0 [10] 2/-1/-1->1->0 [11] 2/-1/-1->1->0 [12] 2/-1/-1->1->0 [13] 2/-1/-1->1->0 [14] 2/-1/-1->1->0 [15] 2/-1/-1->1->0 [16] 2/-1/-1->1->0 [17] 2/-1/-1->1->0 [18] 2/-1/-1->1->0 [19] 2/-1/-1->1->0 [20] 2/-1/-1->1->0 [21] 2/-1/-1->1->0 [22] 2/-1/-1->1->0 [23] 2/-1/-1->1->0 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Trees [0] 6/-1/-1->5->4 [1] 6/-1/-1->5->4 [2] 6/-1/-1->5->4 [3] 6/-1/-1->5->4 [4] 6/-1/-1->5->4 [5] 6/-1/-1->5->4 [6] 6/-1/-1->5->4 [7] 6/-1/-1->5->4 [8] 6/-1/-1->5->4 [9] 6/-1/-1->5->4 [10] 6/-1/-1->5->4 [11] 6/-1/-1->5->4 [12] 6/-1/-1->5->4 [13] 6/-1/-1->5->4 [14] 6/-1/-1->5->4 [15] 6/-1/-1->5->4 [16] 6/-1/-1->5->4 [17] 6/-1/-1->5->4 [18] 6/-1/-1->5->4 [19] 6/-1/-1->5->4 [20] 6/-1/-1->5->4 [21] 6/-1/-1->5->4 [22] 6/-1/-1->5->4 [23] 6/-1/-1->5->4 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1 [2] 3/-1/-1->2->1 [3] 3/-1/-1->2->1 [4] 3/-1/-1->2->1 [5] 3/-1/-1->2->1 [6] 3/-1/-1->2->1 [7] 3/-1/-1->2->1 [8] 3/-1/-1->2->1 [9] 3/-1/-1->2->1 [10] 3/-1/-1->2->1 [11] 3/-1/-1->2->1 [12] 3/-1/-1->2->1 [13] 3/-1/-1->2->1 [14] 3/-1/-1->2->1 [15] 3/-1/-1->2->1 [16] 3/-1/-1->2->1 [17] 3/-1/-1->2->1 [18] 3/-1/-1->2->1 [19] 3/-1/-1->2->1 [20] 3/-1/-1->2->1 [21] 3/-1/-1->2->1 [22] 3/-1/-1->2->1 [23] 3/-1/-1->2->1 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 03/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 04/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 05/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 06/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 07/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 08/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Trees [0] -1/-1/-1->7->6 [1] -1/-1/-1->7->6 [2] -1/-1/-1->7->6 [3] -1/-1/-1->7->6 [4] -1/-1/-1->7->6 [5] -1/-1/-1->7->6 [6] -1/-1/-1->7->6 [7] -1/-1/-1->7->6 [8] -1/-1/-1->7->6 [9] -1/-1/-1->7->6 [10] -1/-1/-1->7->6 [11] -1/-1/-1->7->6 [12] -1/-1/-1->7->6 [13] -1/-1/-1->7->6 [14] -1/-1/-1->7->6 [15] -1/-1/-1->7->6 [16] -1/-1/-1->7->6 [17] -1/-1/-1->7->6 [18] -1/-1/-1->7->6 [19] -1/-1/-1->7->6 [20] -1/-1/-1->7->6 [21] -1/-1/-1->7->6 [22] -1/-1/-1->7->6 [23] -1/-1/-1->7->6 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 09/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 10/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 11/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 12/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 13/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 14/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 15/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 16/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 17/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 18/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 19/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 20/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 21/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 22/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Channel 23/24 : 0 1 2 3 4 5 6 7 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1 [4] 1/-1/-1->0->-1 [5] 1/-1/-1->0->-1 [6] 1/-1/-1->0->-1 [7] 1/-1/-1->0->-1 [8] 1/-1/-1->0->-1 [9] 1/-1/-1->0->-1 [10] 1/-1/-1->0->-1 [11] 1/-1/-1->0->-1 [12] 1/-1/-1->0->-1 [13] 1/-1/-1->0->-1 [14] 1/-1/-1->0->-1 [15] 1/-1/-1->0->-1 [16] 1/-1/-1->0->-1 [17] 1/-1/-1->0->-1 [18] 1/-1/-1->0->-1 [19] 1/-1/-1->0->-1 [20] 1/-1/-1->0->-1 [21] 1/-1/-1->0->-1 [22] 1/-1/-1->0->-1 [23] 1/-1/-1->0->-1 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO P2P Chunksize set to 524288 +t-20260515055614-rg8sr-worker-0:10258:10406 [4] NCCL INFO [Proxy Service] Device 4 CPU core 95 +t-20260515055614-rg8sr-worker-0:10258:10408 [4] NCCL INFO [Proxy Service UDS] Device 4 CPU core 97 +t-20260515055614-rg8sr-worker-0:10260:10407 [6] NCCL INFO [Proxy Service] Device 6 CPU core 96 +t-20260515055614-rg8sr-worker-0:10260:10409 [6] NCCL INFO [Proxy Service UDS] Device 6 CPU core 98 +t-20260515055614-rg8sr-worker-0:10259:10410 [5] NCCL INFO [Proxy Service] Device 5 CPU core 100 +t-20260515055614-rg8sr-worker-0:10259:10411 [5] NCCL INFO [Proxy Service UDS] Device 5 CPU core 102 +t-20260515055614-rg8sr-worker-0:10257:10412 [3] NCCL INFO [Proxy Service] Device 3 CPU core 2 +t-20260515055614-rg8sr-worker-0:10257:10413 [3] NCCL INFO [Proxy Service UDS] Device 3 CPU core 6 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Check P2P Type intraNodeP2pSupport 1 directMode 0 +t-20260515055614-rg8sr-worker-0:10256:10414 [2] NCCL INFO [Proxy Service] Device 2 CPU core 68 +t-20260515055614-rg8sr-worker-0:10256:10416 [2] NCCL INFO [Proxy Service UDS] Device 2 CPU core 70 +t-20260515055614-rg8sr-worker-0:10261:10415 [7] NCCL INFO [Proxy Service] Device 7 CPU core 144 +t-20260515055614-rg8sr-worker-0:10254:10418 [0] NCCL INFO [Proxy Service] Device 0 CPU core 64 +t-20260515055614-rg8sr-worker-0:10261:10417 [7] NCCL INFO [Proxy Service UDS] Device 7 CPU core 146 +t-20260515055614-rg8sr-worker-0:10254:10419 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 70 +t-20260515055614-rg8sr-worker-0:10255:10421 [1] NCCL INFO [Proxy Service UDS] Device 1 CPU core 64 +t-20260515055614-rg8sr-worker-0:10255:10420 [1] NCCL INFO [Proxy Service] Device 1 CPU core 61 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO CC Off, workFifoBytes 1048576 +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol. +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO ncclCommInitRankConfig comm 0xacaf050 rank 3 nranks 8 cudaDev 3 nvmlDev 3 busId 6b020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO ncclCommInitRankConfig comm 0xa2ffc70 rank 2 nranks 8 cudaDev 2 nvmlDev 2 busId 69020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO ncclCommInitRankConfig comm 0xb006800 rank 6 nranks 8 cudaDev 6 nvmlDev 6 busId 73020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO ncclCommInitRankConfig comm 0xb2ac220 rank 4 nranks 8 cudaDev 4 nvmlDev 4 busId 6f020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO ncclCommInitRankConfig comm 0xa26a890 rank 5 nranks 8 cudaDev 5 nvmlDev 5 busId 71020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10257:10323 [3] NCCL INFO Init timings - ncclCommInitRankConfig: rank 3 nranks 8 total 2.18 (kernels 0.19, alloc 0.81, bootstrap 0.24, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol. +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10256:10333 [2] NCCL INFO Init timings - ncclCommInitRankConfig: rank 2 nranks 8 total 2.08 (kernels 0.51, alloc 0.64, bootstrap 0.00, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10259:10330 [5] NCCL INFO Init timings - ncclCommInitRankConfig: rank 5 nranks 8 total 2.11 (kernels 0.23, alloc 0.81, bootstrap 0.14, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO ncclCommInitRankConfig comm 0xabf2300 rank 1 nranks 8 cudaDev 1 nvmlDev 1 busId 67020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead. +t-20260515055614-rg8sr-worker-0:10260:10324 [6] NCCL INFO Init timings - ncclCommInitRankConfig: rank 6 nranks 8 total 2.18 (kernels 0.19, alloc 0.80, bootstrap 0.25, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO ncclCommInitRankConfig comm 0xa8667e0 rank 7 nranks 8 cudaDev 7 nvmlDev 7 busId 75020 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10258:10332 [4] NCCL INFO Init timings - ncclCommInitRankConfig: rank 4 nranks 8 total 2.08 (kernels 0.53, alloc 0.62, bootstrap 0.00, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO ncclCommInitRankConfig comm 0xb26c3d0 rank 0 nranks 8 cudaDev 0 nvmlDev 0 busId 65040 commId 0x721971c0e3ca11b3 - Init COMPLETE +t-20260515055614-rg8sr-worker-0:10261:10329 [7] NCCL INFO Init timings - ncclCommInitRankConfig: rank 7 nranks 8 total 2.15 (kernels 0.19, alloc 0.82, bootstrap 0.20, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10255:10331 [1] NCCL INFO Init timings - ncclCommInitRankConfig: rank 1 nranks 8 total 2.10 (kernels 0.34, alloc 0.76, bootstrap 0.07, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.36, rest 0.03) +t-20260515055614-rg8sr-worker-0:10254:10320 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 8 total 2.19 (kernels 0.20, alloc 0.77, bootstrap 0.29, allgathers 0.01, topo 0.53, graphs 0.01, connections 0.37, rest 0.02) +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 00/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 01/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 02/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 03/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 04/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 05/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 04/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 06/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 05/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 07/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 06/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 08/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 00/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 07/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 09/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 01/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 08/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 10/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 02/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 09/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 11/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 00/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 03/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 10/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 12/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 00/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 01/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 04/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 11/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 13/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 01/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 02/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 05/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 12/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 14/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 02/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 03/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 00/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 06/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 13/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 15/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 00/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 03/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 04/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 01/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 07/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 14/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 16/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 01/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 04/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 05/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 02/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 08/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 15/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 17/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 02/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 05/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 06/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 03/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 09/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 16/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 18/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 03/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 06/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 07/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 04/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 10/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 17/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 19/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 04/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 07/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 08/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 08/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 09/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 05/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 11/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 18/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 20/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 05/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 10/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 02/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 09/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 06/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 12/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 19/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 21/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 03/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 06/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 11/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 10/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 07/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 13/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 20/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 22/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 04/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 12/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 07/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 11/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 08/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 14/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 21/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Channel 23/0 : 2[2] -> 3[3] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 05/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 13/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 08/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 12/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 09/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 15/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 22/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 06/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 14/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 09/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 13/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 10/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 16/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Channel 23/0 : 0[0] -> 1[1] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 07/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 15/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 10/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 14/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 11/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 17/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 08/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 16/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 11/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 15/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 12/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 18/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 09/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 17/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 12/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 16/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 13/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 19/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 10/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 18/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 13/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 17/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 14/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 20/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 11/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 19/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 14/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 18/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 21/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 15/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 12/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 20/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 15/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 19/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 22/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 16/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 13/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 21/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 16/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 20/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Channel 23/0 : 7[7] -> 0[0] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 17/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 14/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 22/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 17/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 21/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 18/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 15/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Channel 23/0 : 4[4] -> 5[5] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 18/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 22/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 19/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 16/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 19/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 20/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Channel 23/0 : 6[6] -> 7[7] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 17/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 20/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 21/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 18/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 21/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 22/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 19/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 22/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Channel 23/0 : 3[3] -> 4[4] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 20/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Channel 23/0 : 5[5] -> 6[6] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 21/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 22/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Channel 23/0 : 1[1] -> 2[2] via P2P/CUMEM +t-20260515055614-rg8sr-worker-0:10260:10425 [6] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10256:10424 [2] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10259:10428 [5] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10255:10429 [1] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10254:10422 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10261:10423 [7] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10257:10427 [3] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +t-20260515055614-rg8sr-worker-0:10258:10426 [4] NCCL INFO Connected all rings, use ring PXN 0 GDR 1 +{ + "device": "cuda:0", + "rank": 0, + "world_size": 8, + "samples": "owt_cached_chunks:8734897", + "vocab_size": 50257, + "tokenizer_vocab_size": 50257, + "save_dir": "runs/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642", + "batch_size": 32, + "grad_accum": 2, + "effective_batch_size": 512, + "global_batch_size": 512, + "lr_schedule": "cosine", + "optimizer": "adamw", + "warmup_steps": 2000, + "min_lr": 6e-05, + "weight_decay": 0.1, + "output_weight_decay": 0.5, + "adamw_param_groups": "nanogpt", + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_eps": 1e-08, + "muon_momentum": 0.95, + "muon_ns_steps": 5, + "muon_update_scale": 1.0, + "ema_decay": 0.0, + "ema_start_step": 0, + "model_type": "ddit", + "output_bias": false, + "norm_type": "rmsnorm", + "dual_t": true, + "corrupt_t_mode": "same", + "corrupt_min_t": 0.0, + "corrupt_max_t": 1.0, + "prefix_block_prob": 0.0, + "prefix_block_len": 128, + "mask_ratio_floor_schedule": "none", + "dirichlet_endpoint_mode": "categorical_dual_t", + "dirichlet_semantic_t_mode": "same", + "dirichlet_semantic_t_value": 0.0, + "dirichlet_semantic_t_curve": "linear", + "dirichlet_semantic_t_power": 1.0, + "endpoint_sequence_random_prob_alpha": 0.0, + "categorical_wrong_from_full_vocab": true, + "categorical_wrong_from_batch_valid_tokens": false, + "categorical_wrong_basin_token_ids": "", + "categorical_wrong_basin_prob": 0.0, + "categorical_wrong_unigram_prob": 0.0, + "categorical_wrong_uniform_prob": 0.0, + "categorical_wrong_corpus_unigram_path": "", + "categorical_wrong_corpus_unigram_alpha": 1.0, + "categorical_wrong_basin_shared_prob": 0.0, + "categorical_wrong_unigram_shared_prob": 0.0, + "mask_mixture_original_prob": 0.0, + "mask_mixture_lowk_prob": 0.0, + "mask_mixture_lowcorrupt_prob": 0.0, + "mask_mixture_block_prob": 0.0, + "mask_mixture_all_prob": 0.0, + "mask_mixture_lowk_clean_tokens": "1,2,4,8,16,32,64", + "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64", + "mask_mixture_block_tokens": "64,128", + "simplex_bridge_sampler": "dirichlet", + "logistic_normal_sigma_min": 0.18, + "logistic_normal_sigma_max": 2.2, + "logistic_normal_tau_min": 0.65, + "logistic_normal_tau_max": 1.15, + "torch_compile": false, + "compile_mode": "max-autotune", + "state_format": "prob", + "target_loss": "hard_ce", + "meanflow_weight": 0.0, + "rollout_train_prob": 0.0, + "rollout_train_steps": 1, + "rollout_train_infer_steps": 64, + "rollout_train_temp": 1.45, + "rollout_train_max_gamma": 1.0, + "rollout_train_corrupt_only": true, + "rollout_train_samplewise": false, + "rollout_train_compute_always": false, + "bridge_noise_init": "logistic_normal", + "noise_sigma": -1.0, + "allow_tf32": true, + "activation_checkpointing": false, + "activation_checkpoint_interval": 1, + "activation_checkpoint_scope": "block", + "ddp_static_graph": false, + "ddp_gradient_as_bucket_view": true, + "blocking_data_transfer": false, + "dataloader_prefetch_factor": 4, + "full_train_stats": false, + "record_pad_truncate": false, + "record_add_eos": false, + "record_add_special_tokens": false, + "record_pad_token": "pad", + "record_shuffle_buffer": 10000, + "wrap": true, + "wrap_mode": "stream", + "wrap_record_buffer_size": 200, + "owt_cached_chunks": true, + "owt_chunk_cache_dir": "/e2e-data/evad-tech-vla/wanghan58/data/small_benchmarks/langflow_2604_11748/openwebtext_lta_cached_chunks/gpt2_len1024_train_minus_100k", + "owt_chunk_cache_rebuild": false, + "owt_chunk_cache_write_batch": 4096, + "owt_exact_repeat_per_chunk": 0, + "online_chunk_shuffle": false, + "online_chunk_shuffle_buffer": 10000, + "openwebtext_split": "train_minus_100k", + "detokenizer": "auto", + "resolved_detokenizer": null, + "num_workers": 8, + "latest_every": 1000, + "resume_path": "" +} +t-20260515055614-rg8sr-worker-0:10259:10782 [5] NCCL INFO NVLS comm 0xa26a890 headRank 5 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10257:10783 [3] NCCL INFO NVLS comm 0xacaf050 headRank 3 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10256:10784 [2] NCCL INFO NVLS comm 0xa2ffc70 headRank 2 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10261:10785 [7] NCCL INFO NVLS comm 0xa8667e0 headRank 7 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10258:10786 [4] NCCL INFO NVLS comm 0xb2ac220 headRank 4 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10255:10787 [1] NCCL INFO NVLS comm 0xabf2300 headRank 1 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10260:10788 [6] NCCL INFO NVLS comm 0xb006800 headRank 6 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +t-20260515055614-rg8sr-worker-0:10254:10789 [0] NCCL INFO NVLS comm 0xb26c3d0 headRank 0 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456 +step=50 micro_steps=100 elapsed=46.0s lr=1.530000e-05 loss=10.7758 loss_recon=10.7758 loss_meanflow=0.0000 mean_model_t=0.4965 mean_corrupt_t=0.4965 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.6742 corrupt_frac=0.5400 acc_corrupt=0.4430 loss_corrupt=10.7758 wrong_frac=0.5076 init_acc_corrupt=0.4583 acc_corrupt_t_0p0_0p2=0.0346 corrupt_frac_t_0p0_0p2=0.2039 acc_corrupt_t_0p2_0p4=0.2356 corrupt_frac_t_0p2_0p4=0.2009 acc_corrupt_t_0p4_0p6=0.4615 corrupt_frac_t_0p4_0p6=0.2110 acc_corrupt_t_0p6_0p8=0.6572 corrupt_frac_t_0p6_0p8=0.1946 acc_corrupt_t_0p8_1p0=0.8615 corrupt_frac_t_0p8_1p0=0.1896 out_w_norm=0.4795 out_g_norm=0.9776 loss_all=10.5919 init_gold_top10=0.5337 init_gold_top100=0.5571 +step=100 micro_steps=200 elapsed=45.2s lr=3.030000e-05 loss=10.1302 loss_recon=10.1302 loss_meanflow=0.0000 mean_model_t=0.4949 mean_corrupt_t=0.4949 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2092 corrupt_frac=0.5482 acc_corrupt=0.1203 loss_corrupt=10.1302 wrong_frac=0.5067 init_acc_corrupt=0.4576 acc_corrupt_t_0p0_0p2=0.0372 corrupt_frac_t_0p0_0p2=0.2143 acc_corrupt_t_0p2_0p4=0.0496 corrupt_frac_t_0p2_0p4=0.1985 acc_corrupt_t_0p4_0p6=0.0938 corrupt_frac_t_0p4_0p6=0.1922 acc_corrupt_t_0p6_0p8=0.1668 corrupt_frac_t_0p6_0p8=0.1921 acc_corrupt_t_0p8_1p0=0.2582 corrupt_frac_t_0p8_1p0=0.2029 out_w_norm=4.3649 out_g_norm=1.8101 loss_all=9.5210 init_gold_top10=0.4932 init_gold_top100=0.5131 +step=150 micro_steps=300 elapsed=45.2s lr=4.530000e-05 loss=8.8480 loss_recon=8.8480 loss_meanflow=0.0000 mean_model_t=0.4941 mean_corrupt_t=0.4941 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1022 corrupt_frac=0.5488 acc_corrupt=0.0690 loss_corrupt=8.8480 wrong_frac=0.5008 init_acc_corrupt=0.4652 acc_corrupt_t_0p0_0p2=0.0353 corrupt_frac_t_0p0_0p2=0.1904 acc_corrupt_t_0p2_0p4=0.0368 corrupt_frac_t_0p2_0p4=0.2073 acc_corrupt_t_0p4_0p6=0.0511 corrupt_frac_t_0p4_0p6=0.2056 acc_corrupt_t_0p6_0p8=0.0906 corrupt_frac_t_0p6_0p8=0.2029 acc_corrupt_t_0p8_1p0=0.1331 corrupt_frac_t_0p8_1p0=0.1938 out_w_norm=12.0630 out_g_norm=1.9269 loss_all=8.1853 init_gold_top10=0.5564 init_gold_top100=0.5713 +step=200 micro_steps=400 elapsed=45.2s lr=6.030000e-05 loss=7.5987 loss_recon=7.5987 loss_meanflow=0.0000 mean_model_t=0.4994 mean_corrupt_t=0.4994 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.1019 corrupt_frac=0.5497 acc_corrupt=0.0654 loss_corrupt=7.5987 wrong_frac=0.5004 init_acc_corrupt=0.4653 acc_corrupt_t_0p0_0p2=0.0351 corrupt_frac_t_0p0_0p2=0.2021 acc_corrupt_t_0p2_0p4=0.0382 corrupt_frac_t_0p2_0p4=0.2008 acc_corrupt_t_0p4_0p6=0.0431 corrupt_frac_t_0p4_0p6=0.1988 acc_corrupt_t_0p6_0p8=0.0782 corrupt_frac_t_0p6_0p8=0.1930 acc_corrupt_t_0p8_1p0=0.1315 corrupt_frac_t_0p8_1p0=0.2053 out_w_norm=20.6401 out_g_norm=1.4235 loss_all=6.7150 init_gold_top10=0.5676 init_gold_top100=0.5845 +step=250 micro_steps=500 elapsed=45.2s lr=7.530000e-05 loss=6.4998 loss_recon=6.4998 loss_meanflow=0.0000 mean_model_t=0.4971 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.2581 corrupt_frac=0.5505 acc_corrupt=0.1707 loss_corrupt=6.4998 wrong_frac=0.5025 init_acc_corrupt=0.4619 acc_corrupt_t_0p0_0p2=0.0434 corrupt_frac_t_0p0_0p2=0.1992 acc_corrupt_t_0p2_0p4=0.0835 corrupt_frac_t_0p2_0p4=0.2055 acc_corrupt_t_0p4_0p6=0.1615 corrupt_frac_t_0p4_0p6=0.1988 acc_corrupt_t_0p6_0p8=0.2559 corrupt_frac_t_0p6_0p8=0.1970 acc_corrupt_t_0p8_1p0=0.3124 corrupt_frac_t_0p8_1p0=0.1995 out_w_norm=28.2191 out_g_norm=0.7069 loss_all=4.8154 init_gold_top10=0.5082 init_gold_top100=0.5321 +step=300 micro_steps=600 elapsed=45.2s lr=9.030000e-05 loss=5.0743 loss_recon=5.0743 loss_meanflow=0.0000 mean_model_t=0.4937 mean_corrupt_t=0.4937 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.5843 corrupt_frac=0.5475 acc_corrupt=0.3939 loss_corrupt=5.0743 wrong_frac=0.5070 init_acc_corrupt=0.4575 acc_corrupt_t_0p0_0p2=0.0557 corrupt_frac_t_0p0_0p2=0.2010 acc_corrupt_t_0p2_0p4=0.2091 corrupt_frac_t_0p2_0p4=0.2080 acc_corrupt_t_0p4_0p6=0.4065 corrupt_frac_t_0p4_0p6=0.2044 acc_corrupt_t_0p6_0p8=0.5808 corrupt_frac_t_0p6_0p8=0.1892 acc_corrupt_t_0p8_1p0=0.7409 corrupt_frac_t_0p8_1p0=0.1974 out_w_norm=34.1727 out_g_norm=0.4427 loss_all=2.4824 init_gold_top10=0.5983 init_gold_top100=0.6241 +step=350 micro_steps=700 elapsed=45.2s lr=1.053000e-04 loss=4.4115 loss_recon=4.4115 loss_meanflow=0.0000 mean_model_t=0.5033 mean_corrupt_t=0.5033 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.6703 corrupt_frac=0.5451 acc_corrupt=0.4571 loss_corrupt=4.4115 wrong_frac=0.4954 init_acc_corrupt=0.4704 acc_corrupt_t_0p0_0p2=0.0577 corrupt_frac_t_0p0_0p2=0.2050 acc_corrupt_t_0p2_0p4=0.2414 corrupt_frac_t_0p2_0p4=0.1920 acc_corrupt_t_0p4_0p6=0.4695 corrupt_frac_t_0p4_0p6=0.1924 acc_corrupt_t_0p6_0p8=0.6594 corrupt_frac_t_0p6_0p8=0.2000 acc_corrupt_t_0p8_1p0=0.8390 corrupt_frac_t_0p8_1p0=0.2106 out_w_norm=39.2515 out_g_norm=0.3457 loss_all=2.7707 init_gold_top10=0.4973 init_gold_top100=0.5261 +step=400 micro_steps=800 elapsed=45.2s lr=1.203000e-04 loss=4.3415 loss_recon=4.3415 loss_meanflow=0.0000 mean_model_t=0.4923 mean_corrupt_t=0.4923 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.6895 corrupt_frac=0.5518 acc_corrupt=0.4626 loss_corrupt=4.3415 wrong_frac=0.5064 init_acc_corrupt=0.4572 acc_corrupt_t_0p0_0p2=0.0594 corrupt_frac_t_0p0_0p2=0.2080 acc_corrupt_t_0p2_0p4=0.2387 corrupt_frac_t_0p2_0p4=0.2025 acc_corrupt_t_0p4_0p6=0.4840 corrupt_frac_t_0p4_0p6=0.2000 acc_corrupt_t_0p6_0p8=0.6845 corrupt_frac_t_0p6_0p8=0.1913 acc_corrupt_t_0p8_1p0=0.8786 corrupt_frac_t_0p8_1p0=0.1982 out_w_norm=42.2768 out_g_norm=0.3825 loss_all=2.5013 init_gold_top10=0.4992 init_gold_top100=0.5279 +step=450 micro_steps=900 elapsed=45.2s lr=1.353000e-04 loss=4.2125 loss_recon=4.2125 loss_meanflow=0.0000 mean_model_t=0.4993 mean_corrupt_t=0.4993 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7077 corrupt_frac=0.5469 acc_corrupt=0.4776 loss_corrupt=4.2125 wrong_frac=0.4993 init_acc_corrupt=0.4658 acc_corrupt_t_0p0_0p2=0.0608 corrupt_frac_t_0p0_0p2=0.1980 acc_corrupt_t_0p2_0p4=0.2490 corrupt_frac_t_0p2_0p4=0.2073 acc_corrupt_t_0p4_0p6=0.4930 corrupt_frac_t_0p4_0p6=0.1941 acc_corrupt_t_0p6_0p8=0.6958 corrupt_frac_t_0p6_0p8=0.2043 acc_corrupt_t_0p8_1p0=0.8969 corrupt_frac_t_0p8_1p0=0.1963 out_w_norm=44.3512 out_g_norm=0.3755 loss_all=2.2941 init_gold_top10=0.5054 init_gold_top100=0.5313 +step=500 micro_steps=1000 elapsed=45.2s lr=1.503000e-04 loss=4.1006 loss_recon=4.1006 loss_meanflow=0.0000 mean_model_t=0.5018 mean_corrupt_t=0.5018 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7131 corrupt_frac=0.5475 acc_corrupt=0.4898 loss_corrupt=4.1006 wrong_frac=0.4942 init_acc_corrupt=0.4723 acc_corrupt_t_0p0_0p2=0.0652 corrupt_frac_t_0p0_0p2=0.1874 acc_corrupt_t_0p2_0p4=0.2541 corrupt_frac_t_0p2_0p4=0.2063 acc_corrupt_t_0p4_0p6=0.4950 corrupt_frac_t_0p4_0p6=0.1993 acc_corrupt_t_0p6_0p8=0.7038 corrupt_frac_t_0p6_0p8=0.2010 acc_corrupt_t_0p8_1p0=0.8964 corrupt_frac_t_0p8_1p0=0.2077 out_w_norm=45.8670 out_g_norm=0.4113 loss_all=2.1710 init_gold_top10=0.5733 init_gold_top100=0.5936 +step=550 micro_steps=1100 elapsed=45.2s lr=1.653000e-04 loss=4.1486 loss_recon=4.1486 loss_meanflow=0.0000 mean_model_t=0.4926 mean_corrupt_t=0.4926 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7040 corrupt_frac=0.5497 acc_corrupt=0.4805 loss_corrupt=4.1486 wrong_frac=0.5111 init_acc_corrupt=0.4533 acc_corrupt_t_0p0_0p2=0.0769 corrupt_frac_t_0p0_0p2=0.2067 acc_corrupt_t_0p2_0p4=0.2740 corrupt_frac_t_0p2_0p4=0.2100 acc_corrupt_t_0p4_0p6=0.5059 corrupt_frac_t_0p4_0p6=0.2059 acc_corrupt_t_0p6_0p8=0.7074 corrupt_frac_t_0p6_0p8=0.1901 acc_corrupt_t_0p8_1p0=0.8994 corrupt_frac_t_0p8_1p0=0.1873 out_w_norm=47.2206 out_g_norm=0.4374 loss_all=2.8802 init_gold_top10=0.4627 init_gold_top100=0.5022 +step=600 micro_steps=1200 elapsed=45.2s lr=1.803000e-04 loss=3.9134 loss_recon=3.9134 loss_meanflow=0.0000 mean_model_t=0.5043 mean_corrupt_t=0.5043 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7205 corrupt_frac=0.5430 acc_corrupt=0.5024 loss_corrupt=3.9134 wrong_frac=0.5013 init_acc_corrupt=0.4640 acc_corrupt_t_0p0_0p2=0.0929 corrupt_frac_t_0p0_0p2=0.2042 acc_corrupt_t_0p2_0p4=0.2875 corrupt_frac_t_0p2_0p4=0.1963 acc_corrupt_t_0p4_0p6=0.5236 corrupt_frac_t_0p4_0p6=0.1977 acc_corrupt_t_0p6_0p8=0.7136 corrupt_frac_t_0p6_0p8=0.2011 acc_corrupt_t_0p8_1p0=0.8966 corrupt_frac_t_0p8_1p0=0.2007 out_w_norm=48.5583 out_g_norm=0.4844 loss_all=2.0677 init_gold_top10=0.5153 init_gold_top100=0.5461 +step=650 micro_steps=1300 elapsed=45.2s lr=1.953000e-04 loss=3.6737 loss_recon=3.6737 loss_meanflow=0.0000 mean_model_t=0.5050 mean_corrupt_t=0.5050 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7298 corrupt_frac=0.5461 acc_corrupt=0.5218 loss_corrupt=3.6737 wrong_frac=0.4986 init_acc_corrupt=0.4667 acc_corrupt_t_0p0_0p2=0.1166 corrupt_frac_t_0p0_0p2=0.1987 acc_corrupt_t_0p2_0p4=0.3096 corrupt_frac_t_0p2_0p4=0.2005 acc_corrupt_t_0p4_0p6=0.5398 corrupt_frac_t_0p4_0p6=0.2013 acc_corrupt_t_0p6_0p8=0.7239 corrupt_frac_t_0p6_0p8=0.1938 acc_corrupt_t_0p8_1p0=0.9088 corrupt_frac_t_0p8_1p0=0.2076 out_w_norm=50.0362 out_g_norm=0.5514 loss_all=2.1810 init_gold_top10=0.4869 init_gold_top100=0.5132 +step=700 micro_steps=1400 elapsed=45.2s lr=2.103000e-04 loss=3.5027 loss_recon=3.5027 loss_meanflow=0.0000 mean_model_t=0.5036 mean_corrupt_t=0.5036 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7370 corrupt_frac=0.5501 acc_corrupt=0.5359 loss_corrupt=3.5027 wrong_frac=0.5002 init_acc_corrupt=0.4656 acc_corrupt_t_0p0_0p2=0.1336 corrupt_frac_t_0p0_0p2=0.2035 acc_corrupt_t_0p2_0p4=0.3305 corrupt_frac_t_0p2_0p4=0.1926 acc_corrupt_t_0p4_0p6=0.5592 corrupt_frac_t_0p4_0p6=0.2024 acc_corrupt_t_0p6_0p8=0.7391 corrupt_frac_t_0p6_0p8=0.1995 acc_corrupt_t_0p8_1p0=0.9131 corrupt_frac_t_0p8_1p0=0.2042 out_w_norm=51.4271 out_g_norm=0.5445 loss_all=2.5369 init_gold_top10=0.3625 init_gold_top100=0.4098 +step=750 micro_steps=1500 elapsed=45.2s lr=2.253000e-04 loss=3.3778 loss_recon=3.3778 loss_meanflow=0.0000 mean_model_t=0.5020 mean_corrupt_t=0.5020 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7433 corrupt_frac=0.5484 acc_corrupt=0.5449 loss_corrupt=3.3778 wrong_frac=0.5002 init_acc_corrupt=0.4641 acc_corrupt_t_0p0_0p2=0.1444 corrupt_frac_t_0p0_0p2=0.1866 acc_corrupt_t_0p2_0p4=0.3405 corrupt_frac_t_0p2_0p4=0.2208 acc_corrupt_t_0p4_0p6=0.5743 corrupt_frac_t_0p4_0p6=0.1995 acc_corrupt_t_0p6_0p8=0.7515 corrupt_frac_t_0p6_0p8=0.1923 acc_corrupt_t_0p8_1p0=0.9148 corrupt_frac_t_0p8_1p0=0.2008 out_w_norm=52.7795 out_g_norm=0.5540 loss_all=2.0706 init_gold_top10=0.4616 init_gold_top100=0.4928 +step=800 micro_steps=1600 elapsed=45.2s lr=2.403000e-04 loss=3.2915 loss_recon=3.2915 loss_meanflow=0.0000 mean_model_t=0.5003 mean_corrupt_t=0.5003 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7452 corrupt_frac=0.5529 acc_corrupt=0.5512 loss_corrupt=3.2915 wrong_frac=0.4990 init_acc_corrupt=0.4661 acc_corrupt_t_0p0_0p2=0.1487 corrupt_frac_t_0p0_0p2=0.2034 acc_corrupt_t_0p2_0p4=0.3483 corrupt_frac_t_0p2_0p4=0.2007 acc_corrupt_t_0p4_0p6=0.5813 corrupt_frac_t_0p4_0p6=0.1906 acc_corrupt_t_0p6_0p8=0.7570 corrupt_frac_t_0p6_0p8=0.1994 acc_corrupt_t_0p8_1p0=0.9195 corrupt_frac_t_0p8_1p0=0.2058 out_w_norm=54.1221 out_g_norm=0.5276 loss_all=1.6476 init_gold_top10=0.4991 init_gold_top100=0.5325 +step=850 micro_steps=1700 elapsed=45.3s lr=2.553000e-04 loss=3.2020 loss_recon=3.2020 loss_meanflow=0.0000 mean_model_t=0.5049 mean_corrupt_t=0.5049 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7509 corrupt_frac=0.5489 acc_corrupt=0.5582 loss_corrupt=3.2020 wrong_frac=0.4976 init_acc_corrupt=0.4673 acc_corrupt_t_0p0_0p2=0.1548 corrupt_frac_t_0p0_0p2=0.1980 acc_corrupt_t_0p2_0p4=0.3583 corrupt_frac_t_0p2_0p4=0.2036 acc_corrupt_t_0p4_0p6=0.5892 corrupt_frac_t_0p4_0p6=0.1997 acc_corrupt_t_0p6_0p8=0.7539 corrupt_frac_t_0p6_0p8=0.1794 acc_corrupt_t_0p8_1p0=0.9182 corrupt_frac_t_0p8_1p0=0.2210 out_w_norm=55.4985 out_g_norm=0.5298 loss_all=1.9441 init_gold_top10=0.4448 init_gold_top100=0.4827 +step=900 micro_steps=1800 elapsed=45.3s lr=2.703000e-04 loss=3.1486 loss_recon=3.1486 loss_meanflow=0.0000 mean_model_t=0.5047 mean_corrupt_t=0.5047 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7518 corrupt_frac=0.5514 acc_corrupt=0.5620 loss_corrupt=3.1486 wrong_frac=0.4975 init_acc_corrupt=0.4681 acc_corrupt_t_0p0_0p2=0.1628 corrupt_frac_t_0p0_0p2=0.1991 acc_corrupt_t_0p2_0p4=0.3595 corrupt_frac_t_0p2_0p4=0.1978 acc_corrupt_t_0p4_0p6=0.5886 corrupt_frac_t_0p4_0p6=0.2030 acc_corrupt_t_0p6_0p8=0.7606 corrupt_frac_t_0p6_0p8=0.1962 acc_corrupt_t_0p8_1p0=0.9226 corrupt_frac_t_0p8_1p0=0.2060 out_w_norm=56.8676 out_g_norm=0.4912 loss_all=1.5493 init_gold_top10=0.5471 init_gold_top100=0.5649 +step=950 micro_steps=1900 elapsed=45.2s lr=2.853000e-04 loss=3.1040 loss_recon=3.1040 loss_meanflow=0.0000 mean_model_t=0.5006 mean_corrupt_t=0.5006 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7535 corrupt_frac=0.5511 acc_corrupt=0.5649 loss_corrupt=3.1040 wrong_frac=0.5000 init_acc_corrupt=0.4657 acc_corrupt_t_0p0_0p2=0.1712 corrupt_frac_t_0p0_0p2=0.2061 acc_corrupt_t_0p2_0p4=0.3787 corrupt_frac_t_0p2_0p4=0.1974 acc_corrupt_t_0p4_0p6=0.5961 corrupt_frac_t_0p4_0p6=0.1943 acc_corrupt_t_0p6_0p8=0.7619 corrupt_frac_t_0p6_0p8=0.2029 acc_corrupt_t_0p8_1p0=0.9254 corrupt_frac_t_0p8_1p0=0.1992 out_w_norm=58.2690 out_g_norm=0.4832 loss_all=1.8259 init_gold_top10=0.5160 init_gold_top100=0.5457 +step=1000 micro_steps=2000 elapsed=45.2s lr=3.003000e-04 loss=3.0978 loss_recon=3.0978 loss_meanflow=0.0000 mean_model_t=0.4944 mean_corrupt_t=0.4944 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7514 corrupt_frac=0.5527 acc_corrupt=0.5629 loss_corrupt=3.0978 wrong_frac=0.5042 init_acc_corrupt=0.4604 acc_corrupt_t_0p0_0p2=0.1739 corrupt_frac_t_0p0_0p2=0.1970 acc_corrupt_t_0p2_0p4=0.3727 corrupt_frac_t_0p2_0p4=0.2128 acc_corrupt_t_0p4_0p6=0.5932 corrupt_frac_t_0p4_0p6=0.1863 acc_corrupt_t_0p6_0p8=0.7616 corrupt_frac_t_0p6_0p8=0.2083 acc_corrupt_t_0p8_1p0=0.9191 corrupt_frac_t_0p8_1p0=0.1981 out_w_norm=59.7138 out_g_norm=0.4876 loss_all=1.7552 init_gold_top10=0.5200 init_gold_top100=0.5453 +step=1050 micro_steps=2100 elapsed=46.9s lr=3.153000e-04 loss=2.9530 loss_recon=2.9530 loss_meanflow=0.0000 mean_model_t=0.5073 mean_corrupt_t=0.5073 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7625 corrupt_frac=0.5483 acc_corrupt=0.5791 loss_corrupt=2.9530 wrong_frac=0.4911 init_acc_corrupt=0.4758 acc_corrupt_t_0p0_0p2=0.1757 corrupt_frac_t_0p0_0p2=0.1926 acc_corrupt_t_0p2_0p4=0.3863 corrupt_frac_t_0p2_0p4=0.1960 acc_corrupt_t_0p4_0p6=0.6057 corrupt_frac_t_0p4_0p6=0.2038 acc_corrupt_t_0p6_0p8=0.7698 corrupt_frac_t_0p6_0p8=0.2081 acc_corrupt_t_0p8_1p0=0.9271 corrupt_frac_t_0p8_1p0=0.2036 out_w_norm=61.2066 out_g_norm=0.4633 loss_all=1.2946 init_gold_top10=0.5825 init_gold_top100=0.6036 +step=1100 micro_steps=2200 elapsed=45.2s lr=3.303000e-04 loss=2.9167 loss_recon=2.9167 loss_meanflow=0.0000 mean_model_t=0.5058 mean_corrupt_t=0.5058 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7640 corrupt_frac=0.5487 acc_corrupt=0.5825 loss_corrupt=2.9167 wrong_frac=0.4909 init_acc_corrupt=0.4759 acc_corrupt_t_0p0_0p2=0.1829 corrupt_frac_t_0p0_0p2=0.1867 acc_corrupt_t_0p2_0p4=0.3855 corrupt_frac_t_0p2_0p4=0.1991 acc_corrupt_t_0p4_0p6=0.6043 corrupt_frac_t_0p4_0p6=0.2035 acc_corrupt_t_0p6_0p8=0.7698 corrupt_frac_t_0p6_0p8=0.2049 acc_corrupt_t_0p8_1p0=0.9273 corrupt_frac_t_0p8_1p0=0.2058 out_w_norm=62.7455 out_g_norm=0.4615 loss_all=1.8148 init_gold_top10=0.4987 init_gold_top100=0.5238 +step=1150 micro_steps=2300 elapsed=45.2s lr=3.453000e-04 loss=2.9463 loss_recon=2.9463 loss_meanflow=0.0000 mean_model_t=0.5025 mean_corrupt_t=0.5025 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7600 corrupt_frac=0.5502 acc_corrupt=0.5763 loss_corrupt=2.9463 wrong_frac=0.4997 init_acc_corrupt=0.4659 acc_corrupt_t_0p0_0p2=0.1858 corrupt_frac_t_0p0_0p2=0.1950 acc_corrupt_t_0p2_0p4=0.3826 corrupt_frac_t_0p2_0p4=0.2036 acc_corrupt_t_0p4_0p6=0.6028 corrupt_frac_t_0p4_0p6=0.1899 acc_corrupt_t_0p6_0p8=0.7695 corrupt_frac_t_0p6_0p8=0.2143 acc_corrupt_t_0p8_1p0=0.9253 corrupt_frac_t_0p8_1p0=0.1994 out_w_norm=64.3157 out_g_norm=0.4428 loss_all=1.9670 init_gold_top10=0.4410 init_gold_top100=0.4779 +step=1200 micro_steps=2400 elapsed=45.2s lr=3.603000e-04 loss=2.9624 loss_recon=2.9624 loss_meanflow=0.0000 mean_model_t=0.5025 mean_corrupt_t=0.5025 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7593 corrupt_frac=0.5482 acc_corrupt=0.5736 loss_corrupt=2.9624 wrong_frac=0.5034 init_acc_corrupt=0.4620 acc_corrupt_t_0p0_0p2=0.1874 corrupt_frac_t_0p0_0p2=0.2051 acc_corrupt_t_0p2_0p4=0.3909 corrupt_frac_t_0p2_0p4=0.2005 acc_corrupt_t_0p4_0p6=0.6065 corrupt_frac_t_0p4_0p6=0.2013 acc_corrupt_t_0p6_0p8=0.7759 corrupt_frac_t_0p6_0p8=0.1908 acc_corrupt_t_0p8_1p0=0.9230 corrupt_frac_t_0p8_1p0=0.2041 out_w_norm=65.9681 out_g_norm=0.4317 loss_all=1.6550 init_gold_top10=0.4590 init_gold_top100=0.4970 +step=1250 micro_steps=2500 elapsed=45.2s lr=3.753000e-04 loss=2.8712 loss_recon=2.8712 loss_meanflow=0.0000 mean_model_t=0.5036 mean_corrupt_t=0.5036 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7636 corrupt_frac=0.5503 acc_corrupt=0.5835 loss_corrupt=2.8712 wrong_frac=0.4962 init_acc_corrupt=0.4701 acc_corrupt_t_0p0_0p2=0.1886 corrupt_frac_t_0p0_0p2=0.1951 acc_corrupt_t_0p2_0p4=0.3999 corrupt_frac_t_0p2_0p4=0.2017 acc_corrupt_t_0p4_0p6=0.6119 corrupt_frac_t_0p4_0p6=0.1992 acc_corrupt_t_0p6_0p8=0.7766 corrupt_frac_t_0p6_0p8=0.2000 acc_corrupt_t_0p8_1p0=0.9258 corrupt_frac_t_0p8_1p0=0.2040 out_w_norm=67.6794 out_g_norm=0.4220 loss_all=1.7963 init_gold_top10=0.5088 init_gold_top100=0.5340 +step=1300 micro_steps=2600 elapsed=45.2s lr=3.903000e-04 loss=2.8643 loss_recon=2.8643 loss_meanflow=0.0000 mean_model_t=0.5036 mean_corrupt_t=0.5036 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7643 corrupt_frac=0.5479 acc_corrupt=0.5835 loss_corrupt=2.8643 wrong_frac=0.4984 init_acc_corrupt=0.4684 acc_corrupt_t_0p0_0p2=0.1874 corrupt_frac_t_0p0_0p2=0.1970 acc_corrupt_t_0p2_0p4=0.3993 corrupt_frac_t_0p2_0p4=0.1955 acc_corrupt_t_0p4_0p6=0.6166 corrupt_frac_t_0p4_0p6=0.2085 acc_corrupt_t_0p6_0p8=0.7802 corrupt_frac_t_0p6_0p8=0.2034 acc_corrupt_t_0p8_1p0=0.9270 corrupt_frac_t_0p8_1p0=0.1955 out_w_norm=69.4536 out_g_norm=0.4111 loss_all=1.4161 init_gold_top10=0.5686 init_gold_top100=0.5899 +step=1350 micro_steps=2700 elapsed=45.2s lr=4.053000e-04 loss=2.8566 loss_recon=2.8566 loss_meanflow=0.0000 mean_model_t=0.4993 mean_corrupt_t=0.4993 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7638 corrupt_frac=0.5461 acc_corrupt=0.5822 loss_corrupt=2.8566 wrong_frac=0.5025 init_acc_corrupt=0.4621 acc_corrupt_t_0p0_0p2=0.1957 corrupt_frac_t_0p0_0p2=0.1973 acc_corrupt_t_0p2_0p4=0.4017 corrupt_frac_t_0p2_0p4=0.2065 acc_corrupt_t_0p4_0p6=0.6160 corrupt_frac_t_0p4_0p6=0.1996 acc_corrupt_t_0p6_0p8=0.7749 corrupt_frac_t_0p6_0p8=0.2033 acc_corrupt_t_0p8_1p0=0.9318 corrupt_frac_t_0p8_1p0=0.1934 out_w_norm=71.2713 out_g_norm=0.3845 loss_all=1.7571 init_gold_top10=0.5019 init_gold_top100=0.5373 +step=1400 micro_steps=2800 elapsed=45.2s lr=4.203000e-04 loss=2.8122 loss_recon=2.8122 loss_meanflow=0.0000 mean_model_t=0.5073 mean_corrupt_t=0.5073 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7619 corrupt_frac=0.5589 acc_corrupt=0.5882 loss_corrupt=2.8122 wrong_frac=0.4948 init_acc_corrupt=0.4720 acc_corrupt_t_0p0_0p2=0.1925 corrupt_frac_t_0p0_0p2=0.1993 acc_corrupt_t_0p2_0p4=0.3981 corrupt_frac_t_0p2_0p4=0.1853 acc_corrupt_t_0p4_0p6=0.6168 corrupt_frac_t_0p4_0p6=0.2111 acc_corrupt_t_0p6_0p8=0.7783 corrupt_frac_t_0p6_0p8=0.1938 acc_corrupt_t_0p8_1p0=0.9268 corrupt_frac_t_0p8_1p0=0.2105 out_w_norm=73.1168 out_g_norm=0.3780 loss_all=1.4098 init_gold_top10=0.4837 init_gold_top100=0.5067 +step=1450 micro_steps=2900 elapsed=45.2s lr=4.353000e-04 loss=2.8179 loss_recon=2.8179 loss_meanflow=0.0000 mean_model_t=0.5003 mean_corrupt_t=0.5003 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7642 corrupt_frac=0.5496 acc_corrupt=0.5865 loss_corrupt=2.8179 wrong_frac=0.4999 init_acc_corrupt=0.4651 acc_corrupt_t_0p0_0p2=0.1991 corrupt_frac_t_0p0_0p2=0.1954 acc_corrupt_t_0p2_0p4=0.3989 corrupt_frac_t_0p2_0p4=0.1985 acc_corrupt_t_0p4_0p6=0.6191 corrupt_frac_t_0p4_0p6=0.2098 acc_corrupt_t_0p6_0p8=0.7799 corrupt_frac_t_0p6_0p8=0.2028 acc_corrupt_t_0p8_1p0=0.9322 corrupt_frac_t_0p8_1p0=0.1934 out_w_norm=74.9763 out_g_norm=0.3696 loss_all=1.4470 init_gold_top10=0.5154 init_gold_top100=0.5427 +step=1500 micro_steps=3000 elapsed=45.2s lr=4.503000e-04 loss=2.7367 loss_recon=2.7367 loss_meanflow=0.0000 mean_model_t=0.5097 mean_corrupt_t=0.5097 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7663 corrupt_frac=0.5597 acc_corrupt=0.5968 loss_corrupt=2.7367 wrong_frac=0.4899 init_acc_corrupt=0.4768 acc_corrupt_t_0p0_0p2=0.2009 corrupt_frac_t_0p0_0p2=0.1976 acc_corrupt_t_0p2_0p4=0.4088 corrupt_frac_t_0p2_0p4=0.1875 acc_corrupt_t_0p4_0p6=0.6210 corrupt_frac_t_0p4_0p6=0.1947 acc_corrupt_t_0p6_0p8=0.7827 corrupt_frac_t_0p6_0p8=0.2137 acc_corrupt_t_0p8_1p0=0.9312 corrupt_frac_t_0p8_1p0=0.2064 out_w_norm=76.9150 out_g_norm=0.3511 loss_all=1.8447 init_gold_top10=0.4210 init_gold_top100=0.4522 +step=1550 micro_steps=3100 elapsed=45.2s lr=4.653000e-04 loss=2.7844 loss_recon=2.7844 loss_meanflow=0.0000 mean_model_t=0.4998 mean_corrupt_t=0.4998 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7664 corrupt_frac=0.5459 acc_corrupt=0.5890 loss_corrupt=2.7844 wrong_frac=0.5007 init_acc_corrupt=0.4641 acc_corrupt_t_0p0_0p2=0.2034 corrupt_frac_t_0p0_0p2=0.2008 acc_corrupt_t_0p2_0p4=0.4112 corrupt_frac_t_0p2_0p4=0.2063 acc_corrupt_t_0p4_0p6=0.6241 corrupt_frac_t_0p4_0p6=0.1910 acc_corrupt_t_0p6_0p8=0.7863 corrupt_frac_t_0p6_0p8=0.2063 acc_corrupt_t_0p8_1p0=0.9300 corrupt_frac_t_0p8_1p0=0.1956 out_w_norm=78.8632 out_g_norm=0.3418 loss_all=2.3159 init_gold_top10=0.3711 init_gold_top100=0.4217 +step=1600 micro_steps=3200 elapsed=45.2s lr=4.803000e-04 loss=2.7282 loss_recon=2.7282 loss_meanflow=0.0000 mean_model_t=0.5006 mean_corrupt_t=0.5006 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7677 corrupt_frac=0.5514 acc_corrupt=0.5958 loss_corrupt=2.7282 wrong_frac=0.4951 init_acc_corrupt=0.4700 acc_corrupt_t_0p0_0p2=0.2097 corrupt_frac_t_0p0_0p2=0.1979 acc_corrupt_t_0p2_0p4=0.4097 corrupt_frac_t_0p2_0p4=0.2027 acc_corrupt_t_0p4_0p6=0.6266 corrupt_frac_t_0p4_0p6=0.1938 acc_corrupt_t_0p6_0p8=0.7869 corrupt_frac_t_0p6_0p8=0.1993 acc_corrupt_t_0p8_1p0=0.9330 corrupt_frac_t_0p8_1p0=0.2081 out_w_norm=80.9097 out_g_norm=0.3277 loss_all=1.1798 init_gold_top10=0.5919 init_gold_top100=0.6168 +step=1650 micro_steps=3300 elapsed=45.2s lr=4.953000e-04 loss=2.7302 loss_recon=2.7302 loss_meanflow=0.0000 mean_model_t=0.4945 mean_corrupt_t=0.4945 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7726 corrupt_frac=0.5342 acc_corrupt=0.5932 loss_corrupt=2.7302 wrong_frac=0.5023 init_acc_corrupt=0.4626 acc_corrupt_t_0p0_0p2=0.2142 corrupt_frac_t_0p0_0p2=0.2020 acc_corrupt_t_0p2_0p4=0.4168 corrupt_frac_t_0p2_0p4=0.2028 acc_corrupt_t_0p4_0p6=0.6265 corrupt_frac_t_0p4_0p6=0.1999 acc_corrupt_t_0p6_0p8=0.7873 corrupt_frac_t_0p6_0p8=0.1952 acc_corrupt_t_0p8_1p0=0.9316 corrupt_frac_t_0p8_1p0=0.2002 out_w_norm=82.9388 out_g_norm=0.3260 loss_all=1.5407 init_gold_top10=0.4965 init_gold_top100=0.5261 +step=1700 micro_steps=3400 elapsed=45.2s lr=5.103000e-04 loss=2.6745 loss_recon=2.6745 loss_meanflow=0.0000 mean_model_t=0.5089 mean_corrupt_t=0.5089 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7732 corrupt_frac=0.5475 acc_corrupt=0.6030 loss_corrupt=2.6745 wrong_frac=0.4895 init_acc_corrupt=0.4772 acc_corrupt_t_0p0_0p2=0.2019 corrupt_frac_t_0p0_0p2=0.1946 acc_corrupt_t_0p2_0p4=0.4210 corrupt_frac_t_0p2_0p4=0.1900 acc_corrupt_t_0p4_0p6=0.6220 corrupt_frac_t_0p4_0p6=0.1951 acc_corrupt_t_0p6_0p8=0.7882 corrupt_frac_t_0p6_0p8=0.2071 acc_corrupt_t_0p8_1p0=0.9339 corrupt_frac_t_0p8_1p0=0.2133 out_w_norm=85.0097 out_g_norm=0.3063 loss_all=1.5621 init_gold_top10=0.4760 init_gold_top100=0.5094 +step=1750 micro_steps=3500 elapsed=45.2s lr=5.253000e-04 loss=2.6796 loss_recon=2.6796 loss_meanflow=0.0000 mean_model_t=0.5029 mean_corrupt_t=0.5029 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7728 corrupt_frac=0.5446 acc_corrupt=0.6002 loss_corrupt=2.6796 wrong_frac=0.4950 init_acc_corrupt=0.4716 acc_corrupt_t_0p0_0p2=0.2091 corrupt_frac_t_0p0_0p2=0.2038 acc_corrupt_t_0p2_0p4=0.4173 corrupt_frac_t_0p2_0p4=0.1808 acc_corrupt_t_0p4_0p6=0.6281 corrupt_frac_t_0p4_0p6=0.2078 acc_corrupt_t_0p6_0p8=0.7914 corrupt_frac_t_0p6_0p8=0.2014 acc_corrupt_t_0p8_1p0=0.9324 corrupt_frac_t_0p8_1p0=0.2062 out_w_norm=87.1179 out_g_norm=0.3183 loss_all=1.5896 init_gold_top10=0.4937 init_gold_top100=0.5215 +step=1800 micro_steps=3600 elapsed=45.2s lr=5.403000e-04 loss=2.7685 loss_recon=2.7685 loss_meanflow=0.0000 mean_model_t=0.4926 mean_corrupt_t=0.4926 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7638 corrupt_frac=0.5473 acc_corrupt=0.5875 loss_corrupt=2.7685 wrong_frac=0.5090 init_acc_corrupt=0.4556 acc_corrupt_t_0p0_0p2=0.2089 corrupt_frac_t_0p0_0p2=0.2155 acc_corrupt_t_0p2_0p4=0.4173 corrupt_frac_t_0p2_0p4=0.1976 acc_corrupt_t_0p4_0p6=0.6341 corrupt_frac_t_0p4_0p6=0.2041 acc_corrupt_t_0p6_0p8=0.7926 corrupt_frac_t_0p6_0p8=0.1918 acc_corrupt_t_0p8_1p0=0.9352 corrupt_frac_t_0p8_1p0=0.1909 out_w_norm=89.2195 out_g_norm=0.2982 loss_all=1.4890 init_gold_top10=0.5165 init_gold_top100=0.5417 +step=1850 micro_steps=3700 elapsed=45.2s lr=5.553000e-04 loss=2.7137 loss_recon=2.7137 loss_meanflow=0.0000 mean_model_t=0.4928 mean_corrupt_t=0.4928 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7689 corrupt_frac=0.5425 acc_corrupt=0.5932 loss_corrupt=2.7137 wrong_frac=0.5074 init_acc_corrupt=0.4567 acc_corrupt_t_0p0_0p2=0.2203 corrupt_frac_t_0p0_0p2=0.2012 acc_corrupt_t_0p2_0p4=0.4171 corrupt_frac_t_0p2_0p4=0.2075 acc_corrupt_t_0p4_0p6=0.6315 corrupt_frac_t_0p4_0p6=0.2083 acc_corrupt_t_0p6_0p8=0.7941 corrupt_frac_t_0p6_0p8=0.1903 acc_corrupt_t_0p8_1p0=0.9324 corrupt_frac_t_0p8_1p0=0.1927 out_w_norm=91.3772 out_g_norm=0.2987 loss_all=1.1057 init_gold_top10=0.5772 init_gold_top100=0.6048 +step=1900 micro_steps=3800 elapsed=45.2s lr=5.703000e-04 loss=2.7134 loss_recon=2.7134 loss_meanflow=0.0000 mean_model_t=0.4990 mean_corrupt_t=0.4990 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7669 corrupt_frac=0.5498 acc_corrupt=0.5942 loss_corrupt=2.7134 wrong_frac=0.5029 init_acc_corrupt=0.4620 acc_corrupt_t_0p0_0p2=0.2092 corrupt_frac_t_0p0_0p2=0.2081 acc_corrupt_t_0p2_0p4=0.4179 corrupt_frac_t_0p2_0p4=0.1978 acc_corrupt_t_0p4_0p6=0.6304 corrupt_frac_t_0p4_0p6=0.1887 acc_corrupt_t_0p6_0p8=0.7933 corrupt_frac_t_0p6_0p8=0.2090 acc_corrupt_t_0p8_1p0=0.9332 corrupt_frac_t_0p8_1p0=0.1964 out_w_norm=93.5658 out_g_norm=0.2956 loss_all=1.4767 init_gold_top10=0.4732 init_gold_top100=0.5109 +step=1950 micro_steps=3900 elapsed=45.2s lr=5.853000e-04 loss=2.6656 loss_recon=2.6656 loss_meanflow=0.0000 mean_model_t=0.5017 mean_corrupt_t=0.5017 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7707 corrupt_frac=0.5513 acc_corrupt=0.6015 loss_corrupt=2.6656 wrong_frac=0.4982 init_acc_corrupt=0.4688 acc_corrupt_t_0p0_0p2=0.2123 corrupt_frac_t_0p0_0p2=0.2044 acc_corrupt_t_0p2_0p4=0.4264 corrupt_frac_t_0p2_0p4=0.1906 acc_corrupt_t_0p4_0p6=0.6359 corrupt_frac_t_0p4_0p6=0.2038 acc_corrupt_t_0p6_0p8=0.7966 corrupt_frac_t_0p6_0p8=0.1970 acc_corrupt_t_0p8_1p0=0.9319 corrupt_frac_t_0p8_1p0=0.2043 out_w_norm=95.7508 out_g_norm=0.2763 loss_all=1.4483 init_gold_top10=0.5029 init_gold_top100=0.5286 +step=2000 micro_steps=4000 elapsed=45.2s lr=6.000000e-04 loss=2.6107 loss_recon=2.6107 loss_meanflow=0.0000 mean_model_t=0.5034 mean_corrupt_t=0.5034 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7732 corrupt_frac=0.5542 acc_corrupt=0.6065 loss_corrupt=2.6107 wrong_frac=0.4957 init_acc_corrupt=0.4717 acc_corrupt_t_0p0_0p2=0.2153 corrupt_frac_t_0p0_0p2=0.1909 acc_corrupt_t_0p2_0p4=0.4214 corrupt_frac_t_0p2_0p4=0.1919 acc_corrupt_t_0p4_0p6=0.6326 corrupt_frac_t_0p4_0p6=0.2197 acc_corrupt_t_0p6_0p8=0.7956 corrupt_frac_t_0p6_0p8=0.1998 acc_corrupt_t_0p8_1p0=0.9354 corrupt_frac_t_0p8_1p0=0.2001 out_w_norm=97.9652 out_g_norm=0.2667 loss_all=1.2884 init_gold_top10=0.5167 init_gold_top100=0.5498 +step=2050 micro_steps=4100 elapsed=47.0s lr=6.000000e-04 loss=2.5762 loss_recon=2.5762 loss_meanflow=0.0000 mean_model_t=0.5086 mean_corrupt_t=0.5086 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7784 corrupt_frac=0.5499 acc_corrupt=0.6116 loss_corrupt=2.5762 wrong_frac=0.4908 init_acc_corrupt=0.4764 acc_corrupt_t_0p0_0p2=0.2162 corrupt_frac_t_0p0_0p2=0.1955 acc_corrupt_t_0p2_0p4=0.4269 corrupt_frac_t_0p2_0p4=0.1939 acc_corrupt_t_0p4_0p6=0.6385 corrupt_frac_t_0p4_0p6=0.1909 acc_corrupt_t_0p6_0p8=0.8011 corrupt_frac_t_0p6_0p8=0.2037 acc_corrupt_t_0p8_1p0=0.9329 corrupt_frac_t_0p8_1p0=0.2160 out_w_norm=100.1782 out_g_norm=0.2628 loss_all=1.8679 init_gold_top10=0.4253 init_gold_top100=0.4692 +step=2100 micro_steps=4200 elapsed=45.2s lr=6.000000e-04 loss=2.6098 loss_recon=2.6098 loss_meanflow=0.0000 mean_model_t=0.4995 mean_corrupt_t=0.4995 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7744 corrupt_frac=0.5558 acc_corrupt=0.6073 loss_corrupt=2.6098 wrong_frac=0.4987 init_acc_corrupt=0.4667 acc_corrupt_t_0p0_0p2=0.2188 corrupt_frac_t_0p0_0p2=0.1980 acc_corrupt_t_0p2_0p4=0.4351 corrupt_frac_t_0p2_0p4=0.2037 acc_corrupt_t_0p4_0p6=0.6436 corrupt_frac_t_0p4_0p6=0.1990 acc_corrupt_t_0p6_0p8=0.8000 corrupt_frac_t_0p6_0p8=0.1974 acc_corrupt_t_0p8_1p0=0.9376 corrupt_frac_t_0p8_1p0=0.2020 out_w_norm=102.3344 out_g_norm=0.2531 loss_all=1.7028 init_gold_top10=0.4853 init_gold_top100=0.5165 +step=2150 micro_steps=4300 elapsed=45.3s lr=6.000000e-04 loss=2.6023 loss_recon=2.6023 loss_meanflow=0.0000 mean_model_t=0.4977 mean_corrupt_t=0.4977 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7771 corrupt_frac=0.5486 acc_corrupt=0.6067 loss_corrupt=2.6023 wrong_frac=0.5040 init_acc_corrupt=0.4611 acc_corrupt_t_0p0_0p2=0.2273 corrupt_frac_t_0p0_0p2=0.1996 acc_corrupt_t_0p2_0p4=0.4319 corrupt_frac_t_0p2_0p4=0.2040 acc_corrupt_t_0p4_0p6=0.6465 corrupt_frac_t_0p4_0p6=0.2051 acc_corrupt_t_0p6_0p8=0.8000 corrupt_frac_t_0p6_0p8=0.1906 acc_corrupt_t_0p8_1p0=0.9366 corrupt_frac_t_0p8_1p0=0.2024 out_w_norm=104.3803 out_g_norm=0.2451 loss_all=1.1388 init_gold_top10=0.5874 init_gold_top100=0.6025 +step=2200 micro_steps=4400 elapsed=45.2s lr=5.999999e-04 loss=2.4671 loss_recon=2.4671 loss_meanflow=0.0000 mean_model_t=0.5105 mean_corrupt_t=0.5105 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7858 corrupt_frac=0.5540 acc_corrupt=0.6253 loss_corrupt=2.4671 wrong_frac=0.4850 init_acc_corrupt=0.4830 acc_corrupt_t_0p0_0p2=0.2252 corrupt_frac_t_0p0_0p2=0.1876 acc_corrupt_t_0p2_0p4=0.4434 corrupt_frac_t_0p2_0p4=0.1927 acc_corrupt_t_0p4_0p6=0.6473 corrupt_frac_t_0p4_0p6=0.2035 acc_corrupt_t_0p6_0p8=0.8090 corrupt_frac_t_0p6_0p8=0.1974 acc_corrupt_t_0p8_1p0=0.9392 corrupt_frac_t_0p8_1p0=0.2211 out_w_norm=106.3015 out_g_norm=0.2363 loss_all=1.4615 init_gold_top10=0.4870 init_gold_top100=0.5130 +step=2250 micro_steps=4500 elapsed=45.2s lr=5.999999e-04 loss=2.5269 loss_recon=2.5269 loss_meanflow=0.0000 mean_model_t=0.4983 mean_corrupt_t=0.4983 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7819 corrupt_frac=0.5510 acc_corrupt=0.6162 loss_corrupt=2.5269 wrong_frac=0.4992 init_acc_corrupt=0.4670 acc_corrupt_t_0p0_0p2=0.2338 corrupt_frac_t_0p0_0p2=0.1953 acc_corrupt_t_0p2_0p4=0.4396 corrupt_frac_t_0p2_0p4=0.1947 acc_corrupt_t_0p4_0p6=0.6511 corrupt_frac_t_0p4_0p6=0.2135 acc_corrupt_t_0p6_0p8=0.8075 corrupt_frac_t_0p6_0p8=0.2024 acc_corrupt_t_0p8_1p0=0.9400 corrupt_frac_t_0p8_1p0=0.1941 out_w_norm=108.1318 out_g_norm=0.2306 loss_all=1.9605 init_gold_top10=0.4190 init_gold_top100=0.4671 +step=2300 micro_steps=4600 elapsed=45.2s lr=5.999999e-04 loss=2.5752 loss_recon=2.5752 loss_meanflow=0.0000 mean_model_t=0.4960 mean_corrupt_t=0.4960 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7784 corrupt_frac=0.5495 acc_corrupt=0.6090 loss_corrupt=2.5752 wrong_frac=0.5057 init_acc_corrupt=0.4588 acc_corrupt_t_0p0_0p2=0.2302 corrupt_frac_t_0p0_0p2=0.2112 acc_corrupt_t_0p2_0p4=0.4390 corrupt_frac_t_0p2_0p4=0.1975 acc_corrupt_t_0p4_0p6=0.6525 corrupt_frac_t_0p4_0p6=0.1960 acc_corrupt_t_0p6_0p8=0.8067 corrupt_frac_t_0p6_0p8=0.1991 acc_corrupt_t_0p8_1p0=0.9394 corrupt_frac_t_0p8_1p0=0.1980 out_w_norm=109.8613 out_g_norm=0.2273 loss_all=1.2937 init_gold_top10=0.5139 init_gold_top100=0.5384 +step=2350 micro_steps=4700 elapsed=45.2s lr=5.999998e-04 loss=2.5579 loss_recon=2.5579 loss_meanflow=0.0000 mean_model_t=0.4913 mean_corrupt_t=0.4913 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7807 corrupt_frac=0.5465 acc_corrupt=0.6111 loss_corrupt=2.5579 wrong_frac=0.5077 init_acc_corrupt=0.4569 acc_corrupt_t_0p0_0p2=0.2275 corrupt_frac_t_0p0_0p2=0.2050 acc_corrupt_t_0p2_0p4=0.4478 corrupt_frac_t_0p2_0p4=0.2076 acc_corrupt_t_0p4_0p6=0.6575 corrupt_frac_t_0p4_0p6=0.1999 acc_corrupt_t_0p6_0p8=0.8126 corrupt_frac_t_0p6_0p8=0.1932 acc_corrupt_t_0p8_1p0=0.9418 corrupt_frac_t_0p8_1p0=0.1944 out_w_norm=111.5440 out_g_norm=0.2298 loss_all=1.2049 init_gold_top10=0.4782 init_gold_top100=0.5008 +step=2400 micro_steps=4800 elapsed=45.2s lr=5.999998e-04 loss=2.4806 loss_recon=2.4806 loss_meanflow=0.0000 mean_model_t=0.4995 mean_corrupt_t=0.4995 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7885 corrupt_frac=0.5415 acc_corrupt=0.6208 loss_corrupt=2.4806 wrong_frac=0.4994 init_acc_corrupt=0.4660 acc_corrupt_t_0p0_0p2=0.2381 corrupt_frac_t_0p0_0p2=0.1991 acc_corrupt_t_0p2_0p4=0.4542 corrupt_frac_t_0p2_0p4=0.1972 acc_corrupt_t_0p4_0p6=0.6615 corrupt_frac_t_0p4_0p6=0.2122 acc_corrupt_t_0p6_0p8=0.8105 corrupt_frac_t_0p6_0p8=0.1977 acc_corrupt_t_0p8_1p0=0.9436 corrupt_frac_t_0p8_1p0=0.1959 out_w_norm=113.1534 out_g_norm=0.2234 loss_all=1.2586 init_gold_top10=0.4605 init_gold_top100=0.4931 +step=2450 micro_steps=4900 elapsed=45.2s lr=5.999997e-04 loss=2.5030 loss_recon=2.5030 loss_meanflow=0.0000 mean_model_t=0.4982 mean_corrupt_t=0.4982 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7821 corrupt_frac=0.5535 acc_corrupt=0.6176 loss_corrupt=2.5030 wrong_frac=0.5042 init_acc_corrupt=0.4618 acc_corrupt_t_0p0_0p2=0.2351 corrupt_frac_t_0p0_0p2=0.1982 acc_corrupt_t_0p2_0p4=0.4489 corrupt_frac_t_0p2_0p4=0.2022 acc_corrupt_t_0p4_0p6=0.6596 corrupt_frac_t_0p4_0p6=0.2054 acc_corrupt_t_0p6_0p8=0.8088 corrupt_frac_t_0p6_0p8=0.1934 acc_corrupt_t_0p8_1p0=0.9381 corrupt_frac_t_0p8_1p0=0.2008 out_w_norm=114.6664 out_g_norm=0.2157 loss_all=1.4379 init_gold_top10=0.4876 init_gold_top100=0.5284 +step=2500 micro_steps=5000 elapsed=45.2s lr=5.999996e-04 loss=2.3970 loss_recon=2.3970 loss_meanflow=0.0000 mean_model_t=0.5058 mean_corrupt_t=0.5058 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7904 corrupt_frac=0.5534 acc_corrupt=0.6320 loss_corrupt=2.3970 wrong_frac=0.4889 init_acc_corrupt=0.4777 acc_corrupt_t_0p0_0p2=0.2467 corrupt_frac_t_0p0_0p2=0.1848 acc_corrupt_t_0p2_0p4=0.4475 corrupt_frac_t_0p2_0p4=0.1952 acc_corrupt_t_0p4_0p6=0.6607 corrupt_frac_t_0p4_0p6=0.2090 acc_corrupt_t_0p6_0p8=0.8174 corrupt_frac_t_0p6_0p8=0.2047 acc_corrupt_t_0p8_1p0=0.9389 corrupt_frac_t_0p8_1p0=0.2063 out_w_norm=116.1243 out_g_norm=0.2132 loss_all=1.0542 init_gold_top10=0.5716 init_gold_top100=0.5944 +step=2550 micro_steps=5100 elapsed=45.2s lr=5.999996e-04 loss=2.5153 loss_recon=2.5153 loss_meanflow=0.0000 mean_model_t=0.4956 mean_corrupt_t=0.4956 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7834 corrupt_frac=0.5468 acc_corrupt=0.6149 loss_corrupt=2.5153 wrong_frac=0.5081 init_acc_corrupt=0.4566 acc_corrupt_t_0p0_0p2=0.2341 corrupt_frac_t_0p0_0p2=0.2045 acc_corrupt_t_0p2_0p4=0.4567 corrupt_frac_t_0p2_0p4=0.2025 acc_corrupt_t_0p4_0p6=0.6620 corrupt_frac_t_0p4_0p6=0.2172 acc_corrupt_t_0p6_0p8=0.8137 corrupt_frac_t_0p6_0p8=0.1856 acc_corrupt_t_0p8_1p0=0.9436 corrupt_frac_t_0p8_1p0=0.1923 out_w_norm=117.5333 out_g_norm=0.2123 loss_all=1.1704 init_gold_top10=0.5557 init_gold_top100=0.5828 +step=2600 micro_steps=5200 elapsed=45.2s lr=5.999995e-04 loss=2.5004 loss_recon=2.5004 loss_meanflow=0.0000 mean_model_t=0.4956 mean_corrupt_t=0.4956 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7837 corrupt_frac=0.5496 acc_corrupt=0.6173 loss_corrupt=2.5004 wrong_frac=0.5076 init_acc_corrupt=0.4571 acc_corrupt_t_0p0_0p2=0.2401 corrupt_frac_t_0p0_0p2=0.1982 acc_corrupt_t_0p2_0p4=0.4558 corrupt_frac_t_0p2_0p4=0.2139 acc_corrupt_t_0p4_0p6=0.6623 corrupt_frac_t_0p4_0p6=0.2011 acc_corrupt_t_0p6_0p8=0.8192 corrupt_frac_t_0p6_0p8=0.2028 acc_corrupt_t_0p8_1p0=0.9396 corrupt_frac_t_0p8_1p0=0.1840 out_w_norm=118.8799 out_g_norm=0.2098 loss_all=1.1687 init_gold_top10=0.5320 init_gold_top100=0.5546 +step=2650 micro_steps=5300 elapsed=45.2s lr=5.999994e-04 loss=2.5290 loss_recon=2.5290 loss_meanflow=0.0000 mean_model_t=0.4919 mean_corrupt_t=0.4919 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7788 corrupt_frac=0.5547 acc_corrupt=0.6127 loss_corrupt=2.5290 wrong_frac=0.5116 init_acc_corrupt=0.4518 acc_corrupt_t_0p0_0p2=0.2431 corrupt_frac_t_0p0_0p2=0.2169 acc_corrupt_t_0p2_0p4=0.4565 corrupt_frac_t_0p2_0p4=0.2099 acc_corrupt_t_0p4_0p6=0.6649 corrupt_frac_t_0p4_0p6=0.1855 acc_corrupt_t_0p6_0p8=0.8163 corrupt_frac_t_0p6_0p8=0.1950 acc_corrupt_t_0p8_1p0=0.9422 corrupt_frac_t_0p8_1p0=0.1928 out_w_norm=120.1524 out_g_norm=0.2109 loss_all=1.5819 init_gold_top10=0.4211 init_gold_top100=0.4596 +step=2700 micro_steps=5400 elapsed=45.2s lr=5.999993e-04 loss=2.4203 loss_recon=2.4203 loss_meanflow=0.0000 mean_model_t=0.5034 mean_corrupt_t=0.5034 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7893 corrupt_frac=0.5509 acc_corrupt=0.6287 loss_corrupt=2.4203 wrong_frac=0.4964 init_acc_corrupt=0.4690 acc_corrupt_t_0p0_0p2=0.2442 corrupt_frac_t_0p0_0p2=0.2025 acc_corrupt_t_0p2_0p4=0.4556 corrupt_frac_t_0p2_0p4=0.1916 acc_corrupt_t_0p4_0p6=0.6654 corrupt_frac_t_0p4_0p6=0.1964 acc_corrupt_t_0p6_0p8=0.8200 corrupt_frac_t_0p6_0p8=0.2014 acc_corrupt_t_0p8_1p0=0.9426 corrupt_frac_t_0p8_1p0=0.2080 out_w_norm=121.3969 out_g_norm=0.2023 loss_all=1.4758 init_gold_top10=0.5014 init_gold_top100=0.5307 +step=2750 micro_steps=5500 elapsed=45.2s lr=5.999992e-04 loss=2.4382 loss_recon=2.4382 loss_meanflow=0.0000 mean_model_t=0.4969 mean_corrupt_t=0.4969 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7870 corrupt_frac=0.5508 acc_corrupt=0.6246 loss_corrupt=2.4382 wrong_frac=0.5034 init_acc_corrupt=0.4623 acc_corrupt_t_0p0_0p2=0.2497 corrupt_frac_t_0p0_0p2=0.2061 acc_corrupt_t_0p2_0p4=0.4613 corrupt_frac_t_0p2_0p4=0.1925 acc_corrupt_t_0p4_0p6=0.6625 corrupt_frac_t_0p4_0p6=0.2043 acc_corrupt_t_0p6_0p8=0.8156 corrupt_frac_t_0p6_0p8=0.2023 acc_corrupt_t_0p8_1p0=0.9444 corrupt_frac_t_0p8_1p0=0.1948 out_w_norm=122.5845 out_g_norm=0.1981 loss_all=1.5147 init_gold_top10=0.4975 init_gold_top100=0.5255 +step=2800 micro_steps=5600 elapsed=45.2s lr=5.999990e-04 loss=2.4309 loss_recon=2.4309 loss_meanflow=0.0000 mean_model_t=0.4979 mean_corrupt_t=0.4979 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7889 corrupt_frac=0.5496 acc_corrupt=0.6267 loss_corrupt=2.4309 wrong_frac=0.5018 init_acc_corrupt=0.4635 acc_corrupt_t_0p0_0p2=0.2444 corrupt_frac_t_0p0_0p2=0.1947 acc_corrupt_t_0p2_0p4=0.4596 corrupt_frac_t_0p2_0p4=0.2071 acc_corrupt_t_0p4_0p6=0.6685 corrupt_frac_t_0p4_0p6=0.2051 acc_corrupt_t_0p6_0p8=0.8226 corrupt_frac_t_0p6_0p8=0.1985 acc_corrupt_t_0p8_1p0=0.9432 corrupt_frac_t_0p8_1p0=0.1945 out_w_norm=123.7159 out_g_norm=0.2013 loss_all=1.9089 init_gold_top10=0.3696 init_gold_top100=0.4217 +step=2850 micro_steps=5700 elapsed=45.2s lr=5.999989e-04 loss=2.3699 loss_recon=2.3699 loss_meanflow=0.0000 mean_model_t=0.5075 mean_corrupt_t=0.5075 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7928 corrupt_frac=0.5507 acc_corrupt=0.6343 loss_corrupt=2.3699 wrong_frac=0.4932 init_acc_corrupt=0.4734 acc_corrupt_t_0p0_0p2=0.2446 corrupt_frac_t_0p0_0p2=0.1903 acc_corrupt_t_0p2_0p4=0.4606 corrupt_frac_t_0p2_0p4=0.1996 acc_corrupt_t_0p4_0p6=0.6689 corrupt_frac_t_0p4_0p6=0.2016 acc_corrupt_t_0p6_0p8=0.8195 corrupt_frac_t_0p6_0p8=0.2018 acc_corrupt_t_0p8_1p0=0.9463 corrupt_frac_t_0p8_1p0=0.2067 out_w_norm=124.8125 out_g_norm=0.1948 loss_all=1.4297 init_gold_top10=0.4560 init_gold_top100=0.4964 +step=2900 micro_steps=5800 elapsed=45.2s lr=5.999988e-04 loss=2.4497 loss_recon=2.4497 loss_meanflow=0.0000 mean_model_t=0.4971 mean_corrupt_t=0.4971 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7842 corrupt_frac=0.5582 acc_corrupt=0.6239 loss_corrupt=2.4497 wrong_frac=0.5042 init_acc_corrupt=0.4615 acc_corrupt_t_0p0_0p2=0.2374 corrupt_frac_t_0p0_0p2=0.2001 acc_corrupt_t_0p2_0p4=0.4528 corrupt_frac_t_0p2_0p4=0.1943 acc_corrupt_t_0p4_0p6=0.6699 corrupt_frac_t_0p4_0p6=0.2152 acc_corrupt_t_0p6_0p8=0.8212 corrupt_frac_t_0p6_0p8=0.2054 acc_corrupt_t_0p8_1p0=0.9412 corrupt_frac_t_0p8_1p0=0.1892 out_w_norm=125.8479 out_g_norm=0.1902 loss_all=1.9794 init_gold_top10=0.5012 init_gold_top100=0.5379 +step=2950 micro_steps=5900 elapsed=45.2s lr=5.999987e-04 loss=2.4219 loss_recon=2.4219 loss_meanflow=0.0000 mean_model_t=0.4937 mean_corrupt_t=0.4937 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7877 corrupt_frac=0.5523 acc_corrupt=0.6271 loss_corrupt=2.4219 wrong_frac=0.5036 init_acc_corrupt=0.4620 acc_corrupt_t_0p0_0p2=0.2494 corrupt_frac_t_0p0_0p2=0.2019 acc_corrupt_t_0p2_0p4=0.4623 corrupt_frac_t_0p2_0p4=0.2075 acc_corrupt_t_0p4_0p6=0.6698 corrupt_frac_t_0p4_0p6=0.1930 acc_corrupt_t_0p6_0p8=0.8196 corrupt_frac_t_0p6_0p8=0.1940 acc_corrupt_t_0p8_1p0=0.9446 corrupt_frac_t_0p8_1p0=0.2057 out_w_norm=126.8189 out_g_norm=0.1959 loss_all=1.4185 init_gold_top10=0.4935 init_gold_top100=0.5246 +step=3000 micro_steps=6000 elapsed=45.2s lr=5.999985e-04 loss=2.3673 loss_recon=2.3673 loss_meanflow=0.0000 mean_model_t=0.5016 mean_corrupt_t=0.5016 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7922 corrupt_frac=0.5525 acc_corrupt=0.6346 loss_corrupt=2.3673 wrong_frac=0.4949 init_acc_corrupt=0.4708 acc_corrupt_t_0p0_0p2=0.2509 corrupt_frac_t_0p0_0p2=0.2013 acc_corrupt_t_0p2_0p4=0.4616 corrupt_frac_t_0p2_0p4=0.1942 acc_corrupt_t_0p4_0p6=0.6764 corrupt_frac_t_0p4_0p6=0.1972 acc_corrupt_t_0p6_0p8=0.8226 corrupt_frac_t_0p6_0p8=0.1941 acc_corrupt_t_0p8_1p0=0.9443 corrupt_frac_t_0p8_1p0=0.2132 out_w_norm=127.7662 out_g_norm=0.1888 loss_all=1.3513 init_gold_top10=0.5676 init_gold_top100=0.5939 +step=3050 micro_steps=6100 elapsed=47.0s lr=5.999984e-04 loss=2.4715 loss_recon=2.4715 loss_meanflow=0.0000 mean_model_t=0.4947 mean_corrupt_t=0.4947 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7822 corrupt_frac=0.5577 acc_corrupt=0.6203 loss_corrupt=2.4715 wrong_frac=0.5092 init_acc_corrupt=0.4552 acc_corrupt_t_0p0_0p2=0.2489 corrupt_frac_t_0p0_0p2=0.2124 acc_corrupt_t_0p2_0p4=0.4586 corrupt_frac_t_0p2_0p4=0.1967 acc_corrupt_t_0p4_0p6=0.6691 corrupt_frac_t_0p4_0p6=0.2074 acc_corrupt_t_0p6_0p8=0.8202 corrupt_frac_t_0p6_0p8=0.1887 acc_corrupt_t_0p8_1p0=0.9431 corrupt_frac_t_0p8_1p0=0.1947 out_w_norm=128.6790 out_g_norm=0.1878 loss_all=1.7912 init_gold_top10=0.4217 init_gold_top100=0.4595 +step=3100 micro_steps=6200 elapsed=45.2s lr=5.999982e-04 loss=2.4097 loss_recon=2.4097 loss_meanflow=0.0000 mean_model_t=0.4933 mean_corrupt_t=0.4933 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7847 corrupt_frac=0.5600 acc_corrupt=0.6269 loss_corrupt=2.4097 wrong_frac=0.5063 init_acc_corrupt=0.4597 acc_corrupt_t_0p0_0p2=0.2486 corrupt_frac_t_0p0_0p2=0.2094 acc_corrupt_t_0p2_0p4=0.4697 corrupt_frac_t_0p2_0p4=0.1944 acc_corrupt_t_0p4_0p6=0.6730 corrupt_frac_t_0p4_0p6=0.2078 acc_corrupt_t_0p6_0p8=0.8250 corrupt_frac_t_0p6_0p8=0.1956 acc_corrupt_t_0p8_1p0=0.9455 corrupt_frac_t_0p8_1p0=0.1928 out_w_norm=129.5505 out_g_norm=0.1919 loss_all=1.3227 init_gold_top10=0.5415 init_gold_top100=0.5693 +step=3150 micro_steps=6300 elapsed=45.2s lr=5.999980e-04 loss=2.4036 loss_recon=2.4036 loss_meanflow=0.0000 mean_model_t=0.4958 mean_corrupt_t=0.4958 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7906 corrupt_frac=0.5450 acc_corrupt=0.6273 loss_corrupt=2.4036 wrong_frac=0.5078 init_acc_corrupt=0.4572 acc_corrupt_t_0p0_0p2=0.2628 corrupt_frac_t_0p0_0p2=0.2009 acc_corrupt_t_0p2_0p4=0.4625 corrupt_frac_t_0p2_0p4=0.2132 acc_corrupt_t_0p4_0p6=0.6721 corrupt_frac_t_0p4_0p6=0.2078 acc_corrupt_t_0p6_0p8=0.8251 corrupt_frac_t_0p6_0p8=0.1818 acc_corrupt_t_0p8_1p0=0.9487 corrupt_frac_t_0p8_1p0=0.1963 out_w_norm=130.3675 out_g_norm=0.1883 loss_all=1.6678 init_gold_top10=0.4613 init_gold_top100=0.4993 +step=3200 micro_steps=6400 elapsed=45.2s lr=5.999979e-04 loss=2.3963 loss_recon=2.3963 loss_meanflow=0.0000 mean_model_t=0.4978 mean_corrupt_t=0.4978 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7893 corrupt_frac=0.5518 acc_corrupt=0.6293 loss_corrupt=2.3963 wrong_frac=0.5034 init_acc_corrupt=0.4615 acc_corrupt_t_0p0_0p2=0.2536 corrupt_frac_t_0p0_0p2=0.2012 acc_corrupt_t_0p2_0p4=0.4679 corrupt_frac_t_0p2_0p4=0.2171 acc_corrupt_t_0p4_0p6=0.6787 corrupt_frac_t_0p4_0p6=0.1876 acc_corrupt_t_0p6_0p8=0.8229 corrupt_frac_t_0p6_0p8=0.2009 acc_corrupt_t_0p8_1p0=0.9472 corrupt_frac_t_0p8_1p0=0.1954 out_w_norm=131.1617 out_g_norm=0.1865 loss_all=1.6213 init_gold_top10=0.4944 init_gold_top100=0.5224 +step=3250 micro_steps=6500 elapsed=45.2s lr=5.999977e-04 loss=2.3644 loss_recon=2.3644 loss_meanflow=0.0000 mean_model_t=0.5032 mean_corrupt_t=0.5032 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7917 corrupt_frac=0.5553 acc_corrupt=0.6352 loss_corrupt=2.3644 wrong_frac=0.4963 init_acc_corrupt=0.4685 acc_corrupt_t_0p0_0p2=0.2499 corrupt_frac_t_0p0_0p2=0.2022 acc_corrupt_t_0p2_0p4=0.4696 corrupt_frac_t_0p2_0p4=0.1995 acc_corrupt_t_0p4_0p6=0.6771 corrupt_frac_t_0p4_0p6=0.1957 acc_corrupt_t_0p6_0p8=0.8276 corrupt_frac_t_0p6_0p8=0.1938 acc_corrupt_t_0p8_1p0=0.9486 corrupt_frac_t_0p8_1p0=0.2088 out_w_norm=131.9091 out_g_norm=0.1855 loss_all=1.1771 init_gold_top10=0.5220 init_gold_top100=0.5475 +step=3300 micro_steps=6600 elapsed=45.3s lr=5.999975e-04 loss=2.3551 loss_recon=2.3551 loss_meanflow=0.0000 mean_model_t=0.5013 mean_corrupt_t=0.5013 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7952 corrupt_frac=0.5446 acc_corrupt=0.6346 loss_corrupt=2.3551 wrong_frac=0.5008 init_acc_corrupt=0.4656 acc_corrupt_t_0p0_0p2=0.2494 corrupt_frac_t_0p0_0p2=0.1965 acc_corrupt_t_0p2_0p4=0.4650 corrupt_frac_t_0p2_0p4=0.2018 acc_corrupt_t_0p4_0p6=0.6831 corrupt_frac_t_0p4_0p6=0.2059 acc_corrupt_t_0p6_0p8=0.8270 corrupt_frac_t_0p6_0p8=0.1956 acc_corrupt_t_0p8_1p0=0.9457 corrupt_frac_t_0p8_1p0=0.2002 out_w_norm=132.6215 out_g_norm=0.1905 loss_all=1.3746 init_gold_top10=0.5187 init_gold_top100=0.5494 +step=3350 micro_steps=6700 elapsed=45.2s lr=5.999973e-04 loss=2.3443 loss_recon=2.3443 loss_meanflow=0.0000 mean_model_t=0.5008 mean_corrupt_t=0.5008 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7927 corrupt_frac=0.5530 acc_corrupt=0.6361 loss_corrupt=2.3443 wrong_frac=0.4972 init_acc_corrupt=0.4684 acc_corrupt_t_0p0_0p2=0.2566 corrupt_frac_t_0p0_0p2=0.1942 acc_corrupt_t_0p2_0p4=0.4574 corrupt_frac_t_0p2_0p4=0.1962 acc_corrupt_t_0p4_0p6=0.6776 corrupt_frac_t_0p4_0p6=0.2077 acc_corrupt_t_0p6_0p8=0.8270 corrupt_frac_t_0p6_0p8=0.2025 acc_corrupt_t_0p8_1p0=0.9445 corrupt_frac_t_0p8_1p0=0.1995 out_w_norm=133.3192 out_g_norm=0.1839 loss_all=1.5987 init_gold_top10=0.3917 init_gold_top100=0.4320 +step=3400 micro_steps=6800 elapsed=45.3s lr=5.999971e-04 loss=2.3637 loss_recon=2.3637 loss_meanflow=0.0000 mean_model_t=0.4978 mean_corrupt_t=0.4978 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7899 corrupt_frac=0.5560 acc_corrupt=0.6327 loss_corrupt=2.3637 wrong_frac=0.5031 init_acc_corrupt=0.4621 acc_corrupt_t_0p0_0p2=0.2565 corrupt_frac_t_0p0_0p2=0.1988 acc_corrupt_t_0p2_0p4=0.4765 corrupt_frac_t_0p2_0p4=0.2114 acc_corrupt_t_0p4_0p6=0.6765 corrupt_frac_t_0p4_0p6=0.1974 acc_corrupt_t_0p6_0p8=0.8264 corrupt_frac_t_0p6_0p8=0.1963 acc_corrupt_t_0p8_1p0=0.9436 corrupt_frac_t_0p8_1p0=0.1978 out_w_norm=133.9930 out_g_norm=0.1822 loss_all=1.4353 init_gold_top10=0.5339 init_gold_top100=0.5562 +step=3450 micro_steps=6900 elapsed=45.2s lr=5.999969e-04 loss=2.3972 loss_recon=2.3972 loss_meanflow=0.0000 mean_model_t=0.4936 mean_corrupt_t=0.4936 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7865 corrupt_frac=0.5584 acc_corrupt=0.6290 loss_corrupt=2.3972 wrong_frac=0.5065 init_acc_corrupt=0.4585 acc_corrupt_t_0p0_0p2=0.2515 corrupt_frac_t_0p0_0p2=0.2113 acc_corrupt_t_0p2_0p4=0.4676 corrupt_frac_t_0p2_0p4=0.1930 acc_corrupt_t_0p4_0p6=0.6779 corrupt_frac_t_0p4_0p6=0.2035 acc_corrupt_t_0p6_0p8=0.8260 corrupt_frac_t_0p6_0p8=0.1953 acc_corrupt_t_0p8_1p0=0.9465 corrupt_frac_t_0p8_1p0=0.1968 out_w_norm=134.6320 out_g_norm=0.1840 loss_all=1.3615 init_gold_top10=0.4868 init_gold_top100=0.5089 +step=3500 micro_steps=7000 elapsed=45.3s lr=5.999967e-04 loss=2.3242 loss_recon=2.3242 loss_meanflow=0.0000 mean_model_t=0.5048 mean_corrupt_t=0.5048 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7930 corrupt_frac=0.5569 acc_corrupt=0.6391 loss_corrupt=2.3242 wrong_frac=0.4954 init_acc_corrupt=0.4709 acc_corrupt_t_0p0_0p2=0.2524 corrupt_frac_t_0p0_0p2=0.1992 acc_corrupt_t_0p2_0p4=0.4713 corrupt_frac_t_0p2_0p4=0.1958 acc_corrupt_t_0p4_0p6=0.6783 corrupt_frac_t_0p4_0p6=0.1939 acc_corrupt_t_0p6_0p8=0.8235 corrupt_frac_t_0p6_0p8=0.2090 acc_corrupt_t_0p8_1p0=0.9480 corrupt_frac_t_0p8_1p0=0.2039 out_w_norm=135.2446 out_g_norm=0.1824 loss_all=1.8724 init_gold_top10=0.4118 init_gold_top100=0.4482 +step=3550 micro_steps=7100 elapsed=45.2s lr=5.999964e-04 loss=2.3566 loss_recon=2.3566 loss_meanflow=0.0000 mean_model_t=0.4979 mean_corrupt_t=0.4979 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7936 corrupt_frac=0.5477 acc_corrupt=0.6349 loss_corrupt=2.3566 wrong_frac=0.5010 init_acc_corrupt=0.4643 acc_corrupt_t_0p0_0p2=0.2530 corrupt_frac_t_0p0_0p2=0.2067 acc_corrupt_t_0p2_0p4=0.4718 corrupt_frac_t_0p2_0p4=0.1973 acc_corrupt_t_0p4_0p6=0.6803 corrupt_frac_t_0p4_0p6=0.1965 acc_corrupt_t_0p6_0p8=0.8237 corrupt_frac_t_0p6_0p8=0.1878 acc_corrupt_t_0p8_1p0=0.9491 corrupt_frac_t_0p8_1p0=0.2134 out_w_norm=135.8276 out_g_norm=0.1808 loss_all=1.5983 init_gold_top10=0.3833 init_gold_top100=0.4337 +step=3600 micro_steps=7200 elapsed=45.2s lr=5.999962e-04 loss=2.3059 loss_recon=2.3059 loss_meanflow=0.0000 mean_model_t=0.5036 mean_corrupt_t=0.5036 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7968 corrupt_frac=0.5490 acc_corrupt=0.6409 loss_corrupt=2.3059 wrong_frac=0.4969 init_acc_corrupt=0.4688 acc_corrupt_t_0p0_0p2=0.2586 corrupt_frac_t_0p0_0p2=0.1935 acc_corrupt_t_0p2_0p4=0.4653 corrupt_frac_t_0p2_0p4=0.1982 acc_corrupt_t_0p4_0p6=0.6805 corrupt_frac_t_0p4_0p6=0.1983 acc_corrupt_t_0p6_0p8=0.8307 corrupt_frac_t_0p6_0p8=0.2097 acc_corrupt_t_0p8_1p0=0.9457 corrupt_frac_t_0p8_1p0=0.2023 out_w_norm=136.3885 out_g_norm=0.1797 loss_all=1.5529 init_gold_top10=0.5134 init_gold_top100=0.5429 +step=3650 micro_steps=7300 elapsed=45.2s lr=5.999960e-04 loss=2.3594 loss_recon=2.3594 loss_meanflow=0.0000 mean_model_t=0.4977 mean_corrupt_t=0.4977 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7924 corrupt_frac=0.5490 acc_corrupt=0.6335 loss_corrupt=2.3594 wrong_frac=0.5033 init_acc_corrupt=0.4604 acc_corrupt_t_0p0_0p2=0.2568 corrupt_frac_t_0p0_0p2=0.2100 acc_corrupt_t_0p2_0p4=0.4684 corrupt_frac_t_0p2_0p4=0.2017 acc_corrupt_t_0p4_0p6=0.6825 corrupt_frac_t_0p4_0p6=0.1851 acc_corrupt_t_0p6_0p8=0.8298 corrupt_frac_t_0p6_0p8=0.1966 acc_corrupt_t_0p8_1p0=0.9470 corrupt_frac_t_0p8_1p0=0.2064 out_w_norm=136.9433 out_g_norm=0.1800 loss_all=1.1706 init_gold_top10=0.5309 init_gold_top100=0.5538 +step=3700 micro_steps=7400 elapsed=45.2s lr=5.999957e-04 loss=2.2650 loss_recon=2.2650 loss_meanflow=0.0000 mean_model_t=0.5029 mean_corrupt_t=0.5029 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8010 corrupt_frac=0.5449 acc_corrupt=0.6454 loss_corrupt=2.2650 wrong_frac=0.4934 init_acc_corrupt=0.4730 acc_corrupt_t_0p0_0p2=0.2614 corrupt_frac_t_0p0_0p2=0.1958 acc_corrupt_t_0p2_0p4=0.4823 corrupt_frac_t_0p2_0p4=0.1959 acc_corrupt_t_0p4_0p6=0.6819 corrupt_frac_t_0p4_0p6=0.1989 acc_corrupt_t_0p6_0p8=0.8321 corrupt_frac_t_0p6_0p8=0.2087 acc_corrupt_t_0p8_1p0=0.9489 corrupt_frac_t_0p8_1p0=0.2007 out_w_norm=137.4403 out_g_norm=0.1773 loss_all=1.5302 init_gold_top10=0.4740 init_gold_top100=0.5061 +step=3750 micro_steps=7500 elapsed=45.2s lr=5.999954e-04 loss=2.2756 loss_recon=2.2756 loss_meanflow=0.0000 mean_model_t=0.5053 mean_corrupt_t=0.5053 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7990 corrupt_frac=0.5486 acc_corrupt=0.6448 loss_corrupt=2.2756 wrong_frac=0.4925 init_acc_corrupt=0.4742 acc_corrupt_t_0p0_0p2=0.2606 corrupt_frac_t_0p0_0p2=0.2006 acc_corrupt_t_0p2_0p4=0.4727 corrupt_frac_t_0p2_0p4=0.1885 acc_corrupt_t_0p4_0p6=0.6823 corrupt_frac_t_0p4_0p6=0.1931 acc_corrupt_t_0p6_0p8=0.8318 corrupt_frac_t_0p6_0p8=0.2118 acc_corrupt_t_0p8_1p0=0.9491 corrupt_frac_t_0p8_1p0=0.2060 out_w_norm=137.9213 out_g_norm=0.1758 loss_all=0.8307 init_gold_top10=0.6185 init_gold_top100=0.6341 +step=3800 micro_steps=7600 elapsed=45.2s lr=5.999952e-04 loss=2.2939 loss_recon=2.2939 loss_meanflow=0.0000 mean_model_t=0.5035 mean_corrupt_t=0.5035 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7935 corrupt_frac=0.5609 acc_corrupt=0.6424 loss_corrupt=2.2939 wrong_frac=0.4950 init_acc_corrupt=0.4710 acc_corrupt_t_0p0_0p2=0.2613 corrupt_frac_t_0p0_0p2=0.1943 acc_corrupt_t_0p2_0p4=0.4686 corrupt_frac_t_0p2_0p4=0.2000 acc_corrupt_t_0p4_0p6=0.6845 corrupt_frac_t_0p4_0p6=0.1941 acc_corrupt_t_0p6_0p8=0.8295 corrupt_frac_t_0p6_0p8=0.2119 acc_corrupt_t_0p8_1p0=0.9481 corrupt_frac_t_0p8_1p0=0.1996 out_w_norm=138.3954 out_g_norm=0.1758 loss_all=1.2377 init_gold_top10=0.4966 init_gold_top100=0.5136 +step=3850 micro_steps=7700 elapsed=45.2s lr=5.999949e-04 loss=2.3202 loss_recon=2.3202 loss_meanflow=0.0000 mean_model_t=0.4978 mean_corrupt_t=0.4978 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7949 corrupt_frac=0.5485 acc_corrupt=0.6372 loss_corrupt=2.3202 wrong_frac=0.5045 init_acc_corrupt=0.4610 acc_corrupt_t_0p0_0p2=0.2668 corrupt_frac_t_0p0_0p2=0.2016 acc_corrupt_t_0p2_0p4=0.4762 corrupt_frac_t_0p2_0p4=0.2053 acc_corrupt_t_0p4_0p6=0.6830 corrupt_frac_t_0p4_0p6=0.1963 acc_corrupt_t_0p6_0p8=0.8282 corrupt_frac_t_0p6_0p8=0.2091 acc_corrupt_t_0p8_1p0=0.9506 corrupt_frac_t_0p8_1p0=0.1877 out_w_norm=138.8407 out_g_norm=0.1773 loss_all=1.2445 init_gold_top10=0.4898 init_gold_top100=0.5220 +step=3900 micro_steps=7800 elapsed=45.2s lr=5.999946e-04 loss=2.3521 loss_recon=2.3521 loss_meanflow=0.0000 mean_model_t=0.4896 mean_corrupt_t=0.4896 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7905 corrupt_frac=0.5537 acc_corrupt=0.6333 loss_corrupt=2.3521 wrong_frac=0.5083 init_acc_corrupt=0.4574 acc_corrupt_t_0p0_0p2=0.2546 corrupt_frac_t_0p0_0p2=0.2180 acc_corrupt_t_0p2_0p4=0.4838 corrupt_frac_t_0p2_0p4=0.1901 acc_corrupt_t_0p4_0p6=0.6874 corrupt_frac_t_0p4_0p6=0.2063 acc_corrupt_t_0p6_0p8=0.8288 corrupt_frac_t_0p6_0p8=0.1815 acc_corrupt_t_0p8_1p0=0.9477 corrupt_frac_t_0p8_1p0=0.2058 out_w_norm=139.2641 out_g_norm=0.1774 loss_all=1.1569 init_gold_top10=0.5058 init_gold_top100=0.5292 +step=3950 micro_steps=7900 elapsed=45.2s lr=5.999943e-04 loss=2.2292 loss_recon=2.2292 loss_meanflow=0.0000 mean_model_t=0.5107 mean_corrupt_t=0.5107 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8038 corrupt_frac=0.5449 acc_corrupt=0.6504 loss_corrupt=2.2292 wrong_frac=0.4889 init_acc_corrupt=0.4778 acc_corrupt_t_0p0_0p2=0.2592 corrupt_frac_t_0p0_0p2=0.1915 acc_corrupt_t_0p2_0p4=0.4806 corrupt_frac_t_0p2_0p4=0.1966 acc_corrupt_t_0p4_0p6=0.6893 corrupt_frac_t_0p4_0p6=0.1983 acc_corrupt_t_0p6_0p8=0.8311 corrupt_frac_t_0p6_0p8=0.2065 acc_corrupt_t_0p8_1p0=0.9489 corrupt_frac_t_0p8_1p0=0.2091 out_w_norm=139.6581 out_g_norm=0.1752 loss_all=1.7421 init_gold_top10=0.4244 init_gold_top100=0.4628 +step=4000 micro_steps=8000 elapsed=45.2s lr=5.999941e-04 loss=2.2955 loss_recon=2.2955 loss_meanflow=0.0000 mean_model_t=0.5039 mean_corrupt_t=0.5039 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7983 corrupt_frac=0.5448 acc_corrupt=0.6409 loss_corrupt=2.2955 wrong_frac=0.4997 init_acc_corrupt=0.4654 acc_corrupt_t_0p0_0p2=0.2638 corrupt_frac_t_0p0_0p2=0.2001 acc_corrupt_t_0p2_0p4=0.4780 corrupt_frac_t_0p2_0p4=0.2026 acc_corrupt_t_0p4_0p6=0.6821 corrupt_frac_t_0p4_0p6=0.1961 acc_corrupt_t_0p6_0p8=0.8279 corrupt_frac_t_0p6_0p8=0.1941 acc_corrupt_t_0p8_1p0=0.9492 corrupt_frac_t_0p8_1p0=0.2090 out_w_norm=140.0401 out_g_norm=0.1726 loss_all=1.1175 init_gold_top10=0.5671 init_gold_top100=0.5854 +step=4050 micro_steps=8100 elapsed=47.0s lr=5.999938e-04 loss=2.2443 loss_recon=2.2443 loss_meanflow=0.0000 mean_model_t=0.5032 mean_corrupt_t=0.5032 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8046 corrupt_frac=0.5378 acc_corrupt=0.6477 loss_corrupt=2.2443 wrong_frac=0.4943 init_acc_corrupt=0.4724 acc_corrupt_t_0p0_0p2=0.2623 corrupt_frac_t_0p0_0p2=0.1917 acc_corrupt_t_0p2_0p4=0.4815 corrupt_frac_t_0p2_0p4=0.1960 acc_corrupt_t_0p4_0p6=0.6826 corrupt_frac_t_0p4_0p6=0.2098 acc_corrupt_t_0p6_0p8=0.8329 corrupt_frac_t_0p6_0p8=0.1945 acc_corrupt_t_0p8_1p0=0.9512 corrupt_frac_t_0p8_1p0=0.2079 out_w_norm=140.3787 out_g_norm=0.1740 loss_all=1.4397 init_gold_top10=0.4812 init_gold_top100=0.5195 +step=4100 micro_steps=8200 elapsed=45.2s lr=5.999934e-04 loss=2.2575 loss_recon=2.2575 loss_meanflow=0.0000 mean_model_t=0.5001 mean_corrupt_t=0.5001 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7981 corrupt_frac=0.5543 acc_corrupt=0.6461 loss_corrupt=2.2575 wrong_frac=0.4962 init_acc_corrupt=0.4713 acc_corrupt_t_0p0_0p2=0.2604 corrupt_frac_t_0p0_0p2=0.1952 acc_corrupt_t_0p2_0p4=0.4863 corrupt_frac_t_0p2_0p4=0.1899 acc_corrupt_t_0p4_0p6=0.6808 corrupt_frac_t_0p4_0p6=0.2052 acc_corrupt_t_0p6_0p8=0.8286 corrupt_frac_t_0p6_0p8=0.2122 acc_corrupt_t_0p8_1p0=0.9489 corrupt_frac_t_0p8_1p0=0.1974 out_w_norm=140.7012 out_g_norm=0.1709 loss_all=1.8329 init_gold_top10=0.4913 init_gold_top100=0.5168 +step=4150 micro_steps=8300 elapsed=45.2s lr=5.999931e-04 loss=2.3150 loss_recon=2.3150 loss_meanflow=0.0000 mean_model_t=0.4994 mean_corrupt_t=0.4994 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7931 corrupt_frac=0.5559 acc_corrupt=0.6388 loss_corrupt=2.3150 wrong_frac=0.5026 init_acc_corrupt=0.4619 acc_corrupt_t_0p0_0p2=0.2617 corrupt_frac_t_0p0_0p2=0.2089 acc_corrupt_t_0p2_0p4=0.4783 corrupt_frac_t_0p2_0p4=0.1972 acc_corrupt_t_0p4_0p6=0.6814 corrupt_frac_t_0p4_0p6=0.1869 acc_corrupt_t_0p6_0p8=0.8302 corrupt_frac_t_0p6_0p8=0.2000 acc_corrupt_t_0p8_1p0=0.9488 corrupt_frac_t_0p8_1p0=0.2070 out_w_norm=141.0129 out_g_norm=0.1736 loss_all=1.4118 init_gold_top10=0.4607 init_gold_top100=0.4903 +step=4200 micro_steps=8400 elapsed=45.2s lr=5.999928e-04 loss=2.3060 loss_recon=2.3060 loss_meanflow=0.0000 mean_model_t=0.4945 mean_corrupt_t=0.4945 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7974 corrupt_frac=0.5442 acc_corrupt=0.6390 loss_corrupt=2.3060 wrong_frac=0.5055 init_acc_corrupt=0.4594 acc_corrupt_t_0p0_0p2=0.2660 corrupt_frac_t_0p0_0p2=0.2008 acc_corrupt_t_0p2_0p4=0.4770 corrupt_frac_t_0p2_0p4=0.2102 acc_corrupt_t_0p4_0p6=0.6873 corrupt_frac_t_0p4_0p6=0.1900 acc_corrupt_t_0p6_0p8=0.8345 corrupt_frac_t_0p6_0p8=0.2127 acc_corrupt_t_0p8_1p0=0.9515 corrupt_frac_t_0p8_1p0=0.1863 out_w_norm=141.3208 out_g_norm=0.1713 loss_all=1.2601 init_gold_top10=0.5089 init_gold_top100=0.5320 +step=4250 micro_steps=8500 elapsed=45.2s lr=5.999925e-04 loss=2.2227 loss_recon=2.2227 loss_meanflow=0.0000 mean_model_t=0.5084 mean_corrupt_t=0.5084 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8024 corrupt_frac=0.5499 acc_corrupt=0.6514 loss_corrupt=2.2227 wrong_frac=0.4910 init_acc_corrupt=0.4767 acc_corrupt_t_0p0_0p2=0.2607 corrupt_frac_t_0p0_0p2=0.1890 acc_corrupt_t_0p2_0p4=0.4857 corrupt_frac_t_0p2_0p4=0.1973 acc_corrupt_t_0p4_0p6=0.6865 corrupt_frac_t_0p4_0p6=0.1983 acc_corrupt_t_0p6_0p8=0.8347 corrupt_frac_t_0p6_0p8=0.2097 acc_corrupt_t_0p8_1p0=0.9485 corrupt_frac_t_0p8_1p0=0.2057 out_w_norm=141.6021 out_g_norm=0.1718 loss_all=1.4447 init_gold_top10=0.3887 init_gold_top100=0.4351 +step=4300 micro_steps=8600 elapsed=45.2s lr=5.999921e-04 loss=2.2350 loss_recon=2.2350 loss_meanflow=0.0000 mean_model_t=0.5040 mean_corrupt_t=0.5040 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8004 corrupt_frac=0.5497 acc_corrupt=0.6478 loss_corrupt=2.2350 wrong_frac=0.4953 init_acc_corrupt=0.4713 acc_corrupt_t_0p0_0p2=0.2637 corrupt_frac_t_0p0_0p2=0.1936 acc_corrupt_t_0p2_0p4=0.4828 corrupt_frac_t_0p2_0p4=0.1965 acc_corrupt_t_0p4_0p6=0.6867 corrupt_frac_t_0p4_0p6=0.2078 acc_corrupt_t_0p6_0p8=0.8310 corrupt_frac_t_0p6_0p8=0.1891 acc_corrupt_t_0p8_1p0=0.9486 corrupt_frac_t_0p8_1p0=0.2130 out_w_norm=141.8717 out_g_norm=0.1710 loss_all=1.1689 init_gold_top10=0.4280 init_gold_top100=0.4692 +step=4350 micro_steps=8700 elapsed=45.2s lr=5.999918e-04 loss=2.2860 loss_recon=2.2860 loss_meanflow=0.0000 mean_model_t=0.4988 mean_corrupt_t=0.4988 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8004 corrupt_frac=0.5411 acc_corrupt=0.6418 loss_corrupt=2.2860 wrong_frac=0.5032 init_acc_corrupt=0.4620 acc_corrupt_t_0p0_0p2=0.2691 corrupt_frac_t_0p0_0p2=0.1999 acc_corrupt_t_0p2_0p4=0.4817 corrupt_frac_t_0p2_0p4=0.2071 acc_corrupt_t_0p4_0p6=0.6874 corrupt_frac_t_0p4_0p6=0.1912 acc_corrupt_t_0p6_0p8=0.8293 corrupt_frac_t_0p6_0p8=0.2060 acc_corrupt_t_0p8_1p0=0.9500 corrupt_frac_t_0p8_1p0=0.1958 out_w_norm=142.1228 out_g_norm=0.1757 loss_all=0.8153 init_gold_top10=0.5566 init_gold_top100=0.5802 +step=4400 micro_steps=8800 elapsed=45.2s lr=5.999914e-04 loss=2.2060 loss_recon=2.2060 loss_meanflow=0.0000 mean_model_t=0.5086 mean_corrupt_t=0.5086 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8058 corrupt_frac=0.5446 acc_corrupt=0.6531 loss_corrupt=2.2060 wrong_frac=0.4900 init_acc_corrupt=0.4764 acc_corrupt_t_0p0_0p2=0.2682 corrupt_frac_t_0p0_0p2=0.1791 acc_corrupt_t_0p2_0p4=0.4853 corrupt_frac_t_0p2_0p4=0.2072 acc_corrupt_t_0p4_0p6=0.6860 corrupt_frac_t_0p4_0p6=0.2053 acc_corrupt_t_0p6_0p8=0.8302 corrupt_frac_t_0p6_0p8=0.2028 acc_corrupt_t_0p8_1p0=0.9502 corrupt_frac_t_0p8_1p0=0.2056 out_w_norm=142.3659 out_g_norm=0.1704 loss_all=1.1609 init_gold_top10=0.5774 init_gold_top100=0.5993 +step=4450 micro_steps=8900 elapsed=45.2s lr=5.999911e-04 loss=2.3561 loss_recon=2.3561 loss_meanflow=0.0000 mean_model_t=0.4875 mean_corrupt_t=0.4875 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.7882 corrupt_frac=0.5573 acc_corrupt=0.6316 loss_corrupt=2.3561 wrong_frac=0.5134 init_acc_corrupt=0.4515 acc_corrupt_t_0p0_0p2=0.2642 corrupt_frac_t_0p0_0p2=0.2088 acc_corrupt_t_0p2_0p4=0.4771 corrupt_frac_t_0p2_0p4=0.2069 acc_corrupt_t_0p4_0p6=0.6819 corrupt_frac_t_0p4_0p6=0.1936 acc_corrupt_t_0p6_0p8=0.8305 corrupt_frac_t_0p6_0p8=0.2114 acc_corrupt_t_0p8_1p0=0.9486 corrupt_frac_t_0p8_1p0=0.1794 out_w_norm=142.6098 out_g_norm=0.1722 loss_all=1.0148 init_gold_top10=0.5676 init_gold_top100=0.5934 +step=4500 micro_steps=9000 elapsed=45.2s lr=5.999907e-04 loss=2.1915 loss_recon=2.1915 loss_meanflow=0.0000 mean_model_t=0.5092 mean_corrupt_t=0.5092 mean_loss_t_weight=1.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.8069 corrupt_frac=0.5427 acc_corrupt=0.6544 loss_corrupt=2.1915 wrong_frac=0.4900 init_acc_corrupt=0.4775 acc_corrupt_t_0p0_0p2=0.2648 corrupt_frac_t_0p0_0p2=0.1900 acc_corrupt_t_0p2_0p4=0.4864 corrupt_frac_t_0p2_0p4=0.1889 acc_corrupt_t_0p4_0p6=0.6892 corrupt_frac_t_0p4_0p6=0.2070 acc_corrupt_t_0p6_0p8=0.8343 corrupt_frac_t_0p6_0p8=0.2057 acc_corrupt_t_0p8_1p0=0.9497 corrupt_frac_t_0p8_1p0=0.2084 out_w_norm=142.8296 out_g_norm=0.1686 loss_all=1.3021 init_gold_top10=0.5303 init_gold_top100=0.5504 diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py new file mode 100644 index 0000000000000000000000000000000000000000..b8bf50af9bf4da1ad78a44934463658d9e430110 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py @@ -0,0 +1,530 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/cohere/modular_cohere.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_cohere.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 Cohere team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is based on the LLama model definition file in transformers + + +from collections.abc import Callable +from typing import Optional + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_cohere import CohereConfig + + +class CohereLayerNorm(nn.Module): + def __init__(self, hidden_size=None, eps=1e-5, bias=False): + """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim""" + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + +class CohereRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: CohereConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: CohereConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class CohereMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +@use_kernelized_func(apply_rotary_pos_emb) +class CohereAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: CohereConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads + self.q_norm = CohereLayerNorm( + hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps + ) + self.k_norm = CohereLayerNorm( + hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CohereDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: CohereConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = CohereAttention(config=config, layer_idx=layer_idx) + self.mlp = CohereMLP(config) + self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_values (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states_attention, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states_mlp = self.mlp(hidden_states) + hidden_states = residual + hidden_states_attention + hidden_states_mlp + return hidden_states + + +@auto_docstring +class CoherePreTrainedModel(PreTrainedModel): + config: CohereConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["CohereDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": CohereDecoderLayer, + "attentions": CohereAttention, + } + + +@auto_docstring +class CohereModel(CoherePreTrainedModel): + def __init__(self, config: CohereConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + self.rotary_emb = CohereRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = CohereModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.logit_scale = config.logit_scale + self.tie_word_embeddings = config.tie_word_embeddings + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >> from transformers import AutoTokenizer, CohereForCausalLM + + >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") + >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + + >> prompt = "Hey, are you conscious? Can you talk to me?" + >> inputs = tokenizer(prompt, return_tensors="pt") + + >> # Generate + >> generate_ids = model.generate(inputs.input_ids, max_length=30) + >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits * self.logit_scale # main diff from Llama + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33d80cdd3425f95de5d40c82a4f52132be971f1f --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_granitemoeshared import * + from .modeling_granitemoeshared import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py new file mode 100644 index 0000000000000000000000000000000000000000..9d782c089e3699ec67372a622f2dafc371aee617 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py @@ -0,0 +1,95 @@ +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GraniteMoeShared model configuration""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="ibm-granite/granite-speech-3.2-8b") +@strict +class GraniteMoeSharedConfig(PreTrainedConfig): + r""" + embedding_multiplier (`float`, *optional*, defaults to 1.0): + embedding multiplier + logits_scaling (`float`, *optional*, defaults to 1.0): + divisor for output logits + residual_multiplier (`float`, *optional*, defaults to 1.0): + residual multiplier + attention_multiplier (`float`, *optional*, defaults to 1.0): + attention multiplier + shared_intermediate_size (`int`, *optional*, defaults to 1024): + intermediate size for shared experts. + + ```python + >>> from transformers import GraniteMoeSharedModel, GraniteMoeSharedConfig + + >>> # Initializing a GraniteMoeShared granitemoe-3b style configuration + >>> configuration = GraniteMoeSharedConfig() + + >>> # Initializing a model from the granitemoe-7b style configuration + >>> model = GraniteMoeSharedModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "granitemoeshared" + keys_to_ignore_at_inference = ["past_key_values"] + + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 1 + eos_token_id: int | list[int] | None = 2 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + attention_dropout: float | int | None = 0.0 + embedding_multiplier: float | int | None = 1.0 + logits_scaling: float | int | None = 1.0 + residual_multiplier: float | int | None = 1.0 + attention_multiplier: float | int | None = 1.0 + num_local_experts: int | None = 8 + num_experts_per_tok: int | None = 2 + output_router_logits: bool | None = False + router_aux_loss_coef: float | None = 0.001 + shared_intermediate_size: int = 0 + + def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + super().__post_init__(**kwargs) + + +__all__ = ["GraniteMoeSharedConfig"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py new file mode 100644 index 0000000000000000000000000000000000000000..71f8c6eaff7dc0836d997ff844f8623b1aaa20c6 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -0,0 +1,800 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granitemoeshared/modular_granitemoeshared.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granitemoeshared.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from typing import Optional, TypedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring +from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_granitemoeshared import GraniteMoeSharedConfig + + +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + cu_seq_lens_q (`torch.LongTensor`): + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`): + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + +class GraniteMoeSharedMLP(nn.Module): + """ + MLP layer for shared experts + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False) + self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.input_linear(hidden_states) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + hidden_states = self.output_linear(hidden_states) + return hidden_states + + +@use_kernel_forward_from_hub("RMSNorm") +class GraniteMoeSharedRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + GraniteMoeSharedRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteMoeSharedParallelExperts(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Initialize the GraniteMoeSharedParallelExperts module. + The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with + many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and + [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the + [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py) + used in vllm. + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def forward(self, inputs, expert_size): + """ + Forward pass of the GraniteMoeSharedParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. + """ + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = torch.cat(output_list, dim=0) + return results + + +class GraniteMoeSharedTopKGating(nn.Module): + def __init__(self, input_size: int, num_experts: int, top_k: int): + """ + Initialize the top-k gating mechanism. + + Args: + input_size (`int`): + Size of the input. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top experts to select. + """ + super().__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.top_k = top_k + + self.layer = nn.Linear(input_size, num_experts, bias=False) + + def forward(self, hidden_states): + # compute the top_k routing decision + logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] + top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] + top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] + + # compute number of input given to each expert + zeros = torch.zeros( + [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device + ) # [num_tokens, num_experts] + gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] + expert_size = gates.long().sum(0) # [num_experts,] + # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`) + # (and `DataDependentOutputException`) + expert_size = expert_size.tolist() + + # sort and group input tokens according to expert assignment + top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] + _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] + batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] + + # gather the gate values for grouped input tokens + top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] + batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] + + return index_sorted_experts, batch_index, batch_gates, expert_size, logits + + +class GraniteMoeSharedMoE(nn.Module): + """ + A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = GraniteMoeSharedParallelExperts( + config.num_local_experts, self.input_size, self.hidden_size * 2 + ) + self.output_linear = GraniteMoeSharedParallelExperts( + config.num_local_experts, self.hidden_size, self.input_size + ) + + self.router = GraniteMoeSharedTopKGating( + input_size=self.input_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + ) + + def forward(self, layer_input): + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + _, batch_index, batch_gates, expert_size, _ = self.router(layer_input) + + expert_inputs = layer_input[batch_index] + hidden_states = self.input_linear(expert_inputs, expert_size) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_size) + + expert_outputs = expert_outputs * batch_gates[:, None] + + zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + return layer_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class GraniteMoeSharedAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.attention_multiplier # Only diff with llama + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx) + self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_sparse_moe = GraniteMoeSharedMoE(config) + self.residual_multiplier = config.residual_multiplier # Only diff with mixtral! + self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + moe_hidden_states = self.block_sparse_moe(hidden_states) + + if self.shared_mlp is None: + hidden_states = moe_hidden_states + else: + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +@auto_docstring +class GraniteMoeSharedPreTrainedModel(PreTrainedModel): + config: GraniteMoeSharedConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GraniteMoeSharedDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()" + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": GraniteMoeSharedDecoderLayer, + "attentions": GraniteMoeSharedAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteMoeSharedParallelExperts): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + + +class GraniteMoeSharedRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GraniteMoeSharedConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: GraniteMoeSharedConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embedding_multiplier = config.embedding_multiplier + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( # ONLY DIFF WITH MIXTRAL: NO SLIDING + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + inputs_embeds = inputs_embeds * self.embedding_multiplier + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__(config) + self.model = GraniteMoeSharedModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.logits_scaling = config.logits_scaling + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + output_router_logits: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> tuple | MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GraniteMoeSharedForCausalLM + + >>> model = GraniteMoeSharedForCausalLM.from_pretrained("ibm/PowerMoE-3b") + >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + # Only compute necessary logits + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits / self.config.logits_scaling + + loss = None + if labels is not None: + # Flatten the tokens + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py new file mode 100644 index 0000000000000000000000000000000000000000..e51cd7712b9b2ffacf2b3493056636e7aeb23a83 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -0,0 +1,154 @@ +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TypedDict + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...processing_utils import Unpack +from ...utils import logging +from ..granitemoe.modeling_granitemoe import ( + GraniteMoeDecoderLayer, + GraniteMoeForCausalLM, + GraniteMoeModel, + GraniteMoePreTrainedModel, +) +from .configuration_granitemoeshared import GraniteMoeSharedConfig + + +logger = logging.get_logger(__name__) + + +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + cu_seq_lens_q (`torch.LongTensor`): + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`): + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + +class GraniteMoeSharedMLP(nn.Module): + """ + MLP layer for shared experts + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False) + self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.input_linear(hidden_states) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + hidden_states = self.output_linear(hidden_states) + return hidden_states + + +class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer): + def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + moe_hidden_states = self.block_sparse_moe(hidden_states) + + if self.shared_mlp is None: + hidden_states = moe_hidden_states + else: + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +class GraniteMoeSharedPreTrainedModel(GraniteMoePreTrainedModel): + config: GraniteMoeSharedConfig + _no_split_modules = ["GraniteMoeSharedDecoderLayer"] + + +class GraniteMoeSharedModel(GraniteMoeModel): + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + +class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config: GraniteMoeSharedConfig): + super().__init__(config) + self.model = GraniteMoeSharedModel(config) + # Initialize weights and apply final processing + self.post_init() + + +__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2bd053d2c8b9fdaec4211b2b74e964bb88a5e3 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_instructblip import * + from .modeling_instructblip import * + from .processing_instructblip import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..07fe02fb613b9f439bf95933af4166dfad7d95c4 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py @@ -0,0 +1,186 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""InstructBLIP model configuration""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from ...utils import auto_docstring, logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl") +@strict +class InstructBlipVisionConfig(PreTrainedConfig): + r""" + Example: + + ```python + >>> from transformers import InstructBlipVisionConfig, InstructBlipVisionModel + + >>> # Initializing a InstructBlipVisionConfig with Salesforce/instructblip-flan-t5-xl style configuration + >>> configuration = InstructBlipVisionConfig() + + >>> # Initializing a InstructBlipVisionModel (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration + >>> model = InstructBlipVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "instructblip_vision_model" + base_config_key = "vision_config" + + hidden_size: int = 1408 + intermediate_size: int = 6144 + num_hidden_layers: int = 39 + num_attention_heads: int = 16 + image_size: int | list[int] | tuple[int, int] = 224 + patch_size: int | list[int] | tuple[int, int] = 14 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-6 + attention_dropout: float | int = 0.0 + initializer_range: float = 1e-10 + qkv_bias: bool = True + + +@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl") +@strict +class InstructBlipQFormerConfig(PreTrainedConfig): + r""" + cross_attention_frequency (`int`, *optional*, defaults to 2): + The frequency of adding cross-attention to the Transformer layers. + encoder_hidden_size (`int`, *optional*, defaults to 1408): + The hidden size of the hidden states for cross-attention. + + Examples: + + ```python + >>> from transformers import InstructBlipQFormerConfig, InstructBlipQFormerModel + + >>> # Initializing a InstructBLIP Salesforce/instructblip-flan-t5-xl style configuration + >>> configuration = InstructBlipQFormerConfig() + + >>> # Initializing a model (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration + >>> model = InstructBlipQFormerModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "instructblip_qformer" + base_config_key = "qformer_config" + + vocab_size: int = 30522 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float | int = 0.1 + attention_probs_dropout_prob: float | int = 0.1 + max_position_embeddings: int = 512 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int | None = 0 + cross_attention_frequency: int = 2 + encoder_hidden_size: int = 1408 + + +@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl") +@strict +class InstructBlipConfig(PreTrainedConfig): + r""" + qformer_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`InstructBlipQFormerConfig`]. + num_query_tokens (`int`, *optional*, defaults to 32): + The number of query tokens passed through the Transformer. + + Example: + + ```python + >>> from transformers import ( + ... InstructBlipVisionConfig, + ... InstructBlipQFormerConfig, + ... OPTConfig, + ... InstructBlipConfig, + ... InstructBlipForConditionalGeneration, + ... ) + + >>> # Initializing a InstructBlipConfig with Salesforce/instructblip-flan-t5-xl style configuration + >>> configuration = InstructBlipConfig() + + >>> # Initializing a InstructBlipForConditionalGeneration (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration + >>> model = InstructBlipForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a InstructBlipConfig from a InstructBlipVisionConfig, InstructBlipQFormerConfig and any PreTrainedConfig + + >>> # Initializing InstructBLIP vision, InstructBLIP Q-Former and language model configurations + >>> vision_config = InstructBlipVisionConfig() + >>> qformer_config = InstructBlipQFormerConfig() + >>> text_config = OPTConfig() + + >>> config = InstructBlipConfig(vision_config=vision_config, qformer_config=qformer_config, text_config=text_config) + ```""" + + model_type = "instructblip" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = { + "text_config": AutoConfig, + "qformer_config": InstructBlipQFormerConfig, + "vision_config": InstructBlipVisionConfig, + } + + vision_config: dict | PreTrainedConfig | None = None + qformer_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + num_query_tokens: int = 32 + image_token_index: int | None = None + initializer_factor: float = 1.0 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + if self.text_config is None: + self.text_config = CONFIG_MAPPING["opt"]() + logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).") + elif isinstance(self.text_config, dict): + text_model_type = self.text_config.get("model_type", "opt") + self.text_config = CONFIG_MAPPING[text_model_type](**self.text_config) + + if self.qformer_config is None: + self.qformer_config = InstructBlipQFormerConfig() + logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.") + elif isinstance(self.qformer_config, dict): + self.qformer_config = InstructBlipQFormerConfig(**self.qformer_config) + + if self.vision_config is None: + self.vision_config = InstructBlipVisionConfig() + logger.info("`vision_config` is `None`. initializing the `InstructBlipVisionConfig` with default values.") + elif isinstance(self.vision_config, dict): + self.vision_config = InstructBlipVisionConfig(**self.vision_config) + + self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size + self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + super().__post_init__(**kwargs) + + +__all__ = ["InstructBlipConfig", "InstructBlipQFormerConfig", "InstructBlipVisionConfig"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..3eec091c4414a5268d1ecd1660899db47946cdd4 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py @@ -0,0 +1,1405 @@ +# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch InstructBLIP model.""" + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...masking_utils import create_bidirectional_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithPast, + Seq2SeqLMOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import apply_chunking_to_forward +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM +from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig + + +logger = logging.get_logger(__name__) + + +@auto_docstring +@dataclass +class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling): + r""" + vision_outputs (`BaseModelOutputWithPooling`): + Outputs of the vision encoder. + qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`): + Outputs of the Q-Former (Querying Transformer). + """ + + vision_outputs: BaseModelOutputWithPooling | None = None + qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None + + +@auto_docstring( + custom_intro=""" + Class defining the outputs of [`InstructBlipForConditionalGeneration`]. + """ +) +@dataclass +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip +class InstructBlipForConditionalGenerationModelOutput(ModelOutput): + r""" + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Language modeling loss from the language model. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head of the language model. + vision_outputs (`BaseModelOutputWithPooling`): + Outputs of the vision encoder. + qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`): + Outputs of the Q-Former (Querying Transformer). + language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`): + Outputs of the language model. + """ + + loss: tuple[torch.FloatTensor] | None = None + logits: tuple[torch.FloatTensor] | None = None + vision_outputs: BaseModelOutputWithPooling | None = None + qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None + language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] + if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"] + else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip +class InstructBlipVisionEmbeddings(nn.Module): + def __init__(self, config: InstructBlipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip +class InstructBlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.is_causal = False + self.attention_dropout = config.attention_dropout + + # small tweak here compared to CLIP, no bias here + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False) + + if config.qkv_bias: + q_bias = nn.Parameter(torch.zeros(self.embed_dim)) + v_bias = nn.Parameter(torch.zeros(self.embed_dim)) + else: + q_bias = None + v_bias = None + + if q_bias is not None: + qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) + self.qkv.bias = nn.Parameter(qkv_bias) + + self.projection = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute( + 2, 0, 3, 1, 4 + ) + query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.projection(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.blip.modeling_blip.BlipMLP +class InstructBlipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip +class InstructBlipEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: InstructBlipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = InstructBlipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = InstructBlipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + **kwargs, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + return hidden_states + + +@auto_docstring +class InstructBlipPreTrainedModel(PreTrainedModel): + config: InstructBlipConfig + base_model_prefix = "blip" + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _supports_attention_backend = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + _no_split_modules = [ + "InstructBlipQFormerEmbeddings", + "InstructBlipAttention", + "InstructBlipQFormerMultiHeadAttention", + "InstructBlipQFormerSelfOutput", + ] + + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + factor = self.config.initializer_range + if isinstance(module, InstructBlipVisionEmbeddings): + init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)): + init.zeros_(module.query_tokens) + elif isinstance(module, InstructBlipQFormerEmbeddings): + init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1))) + + +# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip +class InstructBlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InstructBlipEncoderLayer`]. + + Args: + config (`InstructBlipConfig`): + The corresponding vision configuration for the `InstructBlipEncoder`. + """ + + def __init__(self, config: InstructBlipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @auto_docstring + def forward( + self, + inputs_embeds, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutput: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + **kwargs, + ) + + return BaseModelOutput(last_hidden_state=hidden_states) + + +class InstructBlipVisionModel(InstructBlipPreTrainedModel): + main_input_name = "pixel_values" + input_modalities = ("image",) + config: InstructBlipVisionConfig + _can_record_outputs = { + "hidden_states": InstructBlipEncoderLayer, + "attentions": InstructBlipAttention, + } + + def __init__(self, config: InstructBlipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = InstructBlipVisionEmbeddings(config) + self.encoder = InstructBlipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + + def get_input_embeddings(self): + return self.embeddings + + +class InstructBlipQFormerMultiHeadAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) + self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs: Unpack[TransformersKwargs], + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_scores_dtype = attention_scores.dtype + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer +class InstructBlipQFormerSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip +class InstructBlipQFormerAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention) + self.output = InstructBlipQFormerSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + attn_output, _ = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + attention_output = self.output(attn_output, hidden_states) + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer +class InstructBlipQFormerIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer +class InstructBlipQFormerOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class InstructBlipQFormerLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = InstructBlipQFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = InstructBlipQFormerIntermediate(config) + self.output = InstructBlipQFormerOutput(config) + + self.intermediate_query = InstructBlipQFormerIntermediate(config) + self.output_query = InstructBlipQFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + attention_output = self.attention( + hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + query_attention_output = self.crossattention( + query_attention_output, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + **kwargs, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ).to(layer_output.device) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + return layer_output + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip +class InstructBlipQFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + query_length=0, + **kwargs: Unpack[TransformersKwargs], + ): + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + + hidden_states = layer_module( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + query_length=query_length, + **kwargs, + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + ) + + +class InstructBlipQFormerEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids.to(embeddings.device)) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = embeddings.to(self.layernorm.weight.dtype) + embeddings = self.layernorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class InstructBlipQFormerModel(InstructBlipPreTrainedModel): + """ + Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the + instruction as input. + """ + + _supports_attention_backend = False # adds position on attn weights before last matmul + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + + _can_record_outputs = { + "hidden_states": InstructBlipQFormerLayer, + "attentions": [ + OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"), + ], + "cross_attentions": [ + OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"), + ], + } + + def __init__(self, config: InstructBlipQFormerConfig): + super().__init__(config) + self.config = config + + self.embeddings = InstructBlipQFormerEmbeddings(config) + + self.encoder = InstructBlipQFormerEncoder(config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.FloatTensor | None = None, + position_ids: torch.LongTensor | None = None, + query_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions: + r""" + query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Hidden states to be used in the attention computation. If cross-attention, + will be used for the query (i.e., key and value will use the encoder_hidden_states). + """ + if input_ids is None and query_embeds is None: + raise ValueError("You have to specify query_embeds when input_ids is None") + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + ) + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=embedding_output, + attention_mask=attention_mask, + ) + + if encoder_attention_mask is not None: + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=embedding_output, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + ) + + encoder_outputs: BaseModelOutput = self.encoder( + embedding_output, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + query_length=query_length, + **kwargs, + ) + sequence_output = encoder_outputs.last_hidden_state + pooled_output = sequence_output[:, 0, :] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +@auto_docstring( + custom_intro=""" + InstructBLIP base Model consisting of language model, qformer and vision encoder. + """ +) +class InstructBlipModel(InstructBlipPreTrainedModel): + main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipConfig): + super().__init__(config) + + self.vision_model = InstructBlipVisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipQFormerModel(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + self.language_model = AutoModel.from_config(config.text_config) + + # Initialize weights and apply final processing + self.post_init() + + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." + " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: torch.LongTensor | None = None, + input_ids: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + inputs_embeds: torch.Tensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | InstructBlipForConditionalGenerationModelOutput: + r""" + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + """ + + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + **kwargs, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **kwargs, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + else: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + + return InstructBlipForConditionalGenerationModelOutput( + vision_outputs=vision_outputs, + qformer_outputs=query_outputs, + language_model_outputs=outputs, + ) + + +@auto_docstring( + custom_intro=""" + InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision + encoder, Querying Transformer (Q-Former) and a language model. + + One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue + the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token. + """ +) +class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): + config: InstructBlipConfig + main_input_name = "pixel_values" + + _can_compile_fullgraph = True + _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 + + def __init__(self, config: InstructBlipConfig): + super().__init__(config) + + self.vision_model = InstructBlipVisionModel._from_config(config.vision_config) + + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config) + + self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) + + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config(config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) + + self.language_model = language_model + + # Initialize weights and apply final processing + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def get_encoder(self, modality=None): + if modality is None: + return self.language_model.get_encoder() + else: + return super().get_encoder(modality=modality) + + def get_decoder(self): + return self.language_model.get_decoder() + + # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate + def _preprocess_accelerate(self): + r""" + Some pre-processing hacks to make the model `accelerate` compatible. Check + https://github.com/huggingface/transformers/pull/21707 for more details. + """ + hf_device_map = self.hf_device_map + + if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1: + # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`. + logger.warning( + "The `language_model` is not in the `hf_device_map` dictionary and you are running your script" + " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`." + " Please pass a `device_map` that contains `language_model` to remove this warning." + " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for" + " more details on creating a `device_map` for large models.", + ) + + if hasattr(self.language_model, "_hf_hook"): + self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: torch.LongTensor | None = None, + interpolate_pos_encoding: bool | None = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithVisionQformerOutputs: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + """ + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + **kwargs, + ) + vision_outputs = BaseModelOutputWithVisionQformerOutputs(**vision_outputs, vision_outputs=vision_outputs) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + qformer_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + **kwargs, + ) + vision_outputs.qformer_outputs = qformer_outputs + query_output = qformer_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + image_features = self.language_projection(query_output) + vision_outputs.pooler_output = image_features + + return vision_outputs + + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.FloatTensor, + qformer_attention_mask: torch.LongTensor | None = None, + input_ids: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | InstructBlipForConditionalGenerationModelOutput: + r""" + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. + + Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + Only relevant in case an encoder-decoder language model (like T5) is used. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size - + 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size]` + + Examples: + + ```python + >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration + >>> import torch + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + + >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b") + >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b") + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model.to(device) # doctest: +IGNORE_RESULT + + >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())).convert("RGB") + >>> prompt = "What is unusual about this image?" + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) + + >>> outputs = model.generate( + ... **inputs, + ... do_sample=False, + ... num_beams=5, + ... max_length=256, + ... min_length=1, + ... top_p=0.9, + ... repetition_penalty=1.5, + ... length_penalty=1.0, + ... temperature=1, + ... ) + >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip() + >>> print(generated_text) + The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation. + ```""" + + image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + language_model_inputs = image_features.pooler_output + qformer_outputs = image_features.qformer_outputs + vision_outputs = image_features.vision_outputs + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + if self.config.use_decoder_only_language_model: + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + **kwargs, + ) + logits = outputs[0] + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + else: + kwargs["return_dict"] = True + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels, + **kwargs, + ) + loss = outputs.loss + logits = outputs.logits + + return InstructBlipForConditionalGenerationModelOutput( + loss=loss, + logits=logits, + vision_outputs=vision_outputs, + qformer_outputs=qformer_outputs, + language_model_outputs=outputs, + ) + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor | None = None, + qformer_attention_mask: torch.LongTensor | None = None, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + interpolate_pos_encoding: bool = False, + **generate_kwargs, + ) -> torch.LongTensor: + """ + Overrides `generate` function to be able to use the model as a conditional generator. + + Args: + pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)): + Input images to be processed. + qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + The sequence used as a prompt to be fed to the Q-Former module. + qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + Mask to avoid performing attention on padding token indices. + input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + The sequence used as a prompt for the generation. + attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): + Mask to avoid performing attention on padding token indices. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. + + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + if hasattr(self, "hf_device_map"): + # preprocess for `accelerate` + self._preprocess_accelerate() + + batch_size = pixel_values.shape[0] + image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + language_model_inputs = image_features.pooler_output + + if inputs_embeds is None: + if input_ids is None: + image_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens = image_tokens + [self.config.text_config.bos_token_id] + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + + inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} + if not self.language_model.config.is_encoder_decoder: + inputs["input_ids"] = input_ids + + outputs = self.language_model.generate(**inputs, **generate_kwargs) + + return outputs + + +__all__ = [ + "InstructBlipQFormerModel", + "InstructBlipPreTrainedModel", + "InstructBlipModel", + "InstructBlipForConditionalGeneration", + "InstructBlipVisionModel", +] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py new file mode 100644 index 0000000000000000000000000000000000000000..10c69dd79d1d4d7b2a65a6f326dbad9a3f40b3be --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py @@ -0,0 +1,123 @@ +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for InstructBLIP. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former. +""" + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput +from ...utils import auto_docstring, logging + + +logger = logging.get_logger(__name__) + + +class InstructBlipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + } + + +@auto_docstring +class InstructBlipProcessor(ProcessorMixin): + def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs): + r""" + qformer_tokenizer (`AutoTokenizer`): + An instance of ['PreTrainedTokenizer`]. The Q-Former tokenizer is a required input. + num_query_tokens (`int`, *optional*): + " + Number of tokens used by the Qformer as queries, should be same as in model's config. + """ + if not hasattr(tokenizer, "image_token"): + self.image_token = AddedToken("", normalized=False, special=True) + tokenizer.add_tokens([self.image_token], special_tokens=True) + else: + self.image_token = tokenizer.image_token + self.num_query_tokens = num_query_tokens + + super().__init__(image_processor, tokenizer, qformer_tokenizer) + + @auto_docstring + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[InstructBlipProcessorKwargs], + ) -> BatchFeature: + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + + output_kwargs = self._merge_kwargs( + InstructBlipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + encoding = {} + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"]) + encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids") + encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask") + + # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token + if output_kwargs["text_kwargs"].get("max_length") is not None: + output_kwargs["text_kwargs"]["max_length"] -= self.num_query_tokens + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if images is not None: + # Image tokens should not be padded/truncated or prepended with special BOS token + image_tokens = self.image_token.content * self.num_query_tokens + output_kwargs["text_kwargs"]["add_special_tokens"] = False + output_kwargs["text_kwargs"]["padding"] = False + output_kwargs["text_kwargs"]["truncation"] = False + image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"]) + for k in text_encoding: + text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]] + encoding.update(text_encoding) + + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + encoding.update(image_encoding) + + # Cast to desired return tensors type + encoding = BatchFeature(encoding, tensor_type=return_tensors) + return encoding + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + qformer_input_names = ["qformer_input_ids", "qformer_attention_mask"] + return tokenizer_input_names + image_processor_input_names + qformer_input_names + + +__all__ = ["InstructBlipProcessor"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b64f2a8779692c561e05c1e6deb6d39ec927458 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mllama import * + from .image_processing_mllama import * + from .image_processing_pil_mllama import * + from .modeling_mllama import * + from .processing_mllama import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py new file mode 100644 index 0000000000000000000000000000000000000000..763360ef82117465be6fa08f8a595755403379f2 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py @@ -0,0 +1,963 @@ +# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE +"""PyTorch MobileViT model.""" + +import math + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ... import initialization as init +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, + SemanticSegmenterOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, logging, torch_int +from .configuration_mobilevit import MobileViTConfig + + +logger = logging.get_logger(__name__) + + +def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int: + """ + Ensure that all layers have a channel count that is divisible by `divisor`. + """ + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_value < 0.9 * value: + new_value += divisor + return int(new_value) + + +class MobileViTConvLayer(nn.Module): + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = False, + dilation: int = 1, + use_normalization: bool = True, + use_activation: bool | str = True, + ) -> None: + super().__init__() + padding = int((kernel_size - 1) / 2) * dilation + + if in_channels % groups != 0: + raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.") + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + self.convolution = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode="zeros", + ) + + if use_normalization: + self.normalization = nn.BatchNorm2d( + num_features=out_channels, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ) + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = ACT2FN[use_activation] + elif isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + else: + self.activation = None + + def forward(self, features: torch.Tensor) -> torch.Tensor: + features = self.convolution(features) + if self.normalization is not None: + features = self.normalization(features) + if self.activation is not None: + features = self.activation(features) + return features + + +class MobileViTInvertedResidual(nn.Module): + """ + Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381 + """ + + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1 + ) -> None: + super().__init__() + expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) + + if stride not in [1, 2]: + raise ValueError(f"Invalid stride {stride}.") + + self.use_residual = (stride == 1) and (in_channels == out_channels) + + self.expand_1x1 = MobileViTConvLayer( + config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1 + ) + + self.conv_3x3 = MobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=expanded_channels, + kernel_size=3, + stride=stride, + groups=expanded_channels, + dilation=dilation, + ) + + self.reduce_1x1 = MobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=out_channels, + kernel_size=1, + use_activation=False, + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + residual = features + + features = self.expand_1x1(features) + features = self.conv_3x3(features) + features = self.reduce_1x1(features) + + return residual + features if self.use_residual else features + + +class MobileViTMobileNetLayer(nn.Module): + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1 + ) -> None: + super().__init__() + + self.layer = nn.ModuleList() + for i in range(num_stages): + layer = MobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + ) + self.layer.append(layer) + in_channels = out_channels + + def forward(self, features: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + features = layer_module(features) + return features + + +class MobileViTSelfAttention(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + + if hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class MobileViTSelfOutput(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class MobileViTAttention(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int) -> None: + super().__init__() + self.attention = MobileViTSelfAttention(config, hidden_size) + self.output = MobileViTSelfOutput(config, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + self_outputs = self.attention(hidden_states) + attention_output = self.output(self_outputs) + return attention_output + + +class MobileViTIntermediate(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class MobileViTOutput(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class MobileViTTransformerLayer(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + self.attention = MobileViTAttention(config, hidden_size) + self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size) + self.output = MobileViTOutput(config, hidden_size, intermediate_size) + self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states)) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + return layer_output + + +class MobileViTTransformer(nn.Module): + def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None: + super().__init__() + + self.layer = nn.ModuleList() + for _ in range(num_stages): + transformer_layer = MobileViTTransformerLayer( + config, + hidden_size=hidden_size, + intermediate_size=int(hidden_size * config.mlp_ratio), + ) + self.layer.append(transformer_layer) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer_module in self.layer: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class MobileViTLayer(GradientCheckpointingLayer): + """ + MobileViT block: https://huggingface.co/papers/2110.02178 + """ + + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int, + hidden_size: int, + num_stages: int, + dilation: int = 1, + ) -> None: + super().__init__() + self.patch_width = config.patch_size + self.patch_height = config.patch_size + + if stride == 2: + self.downsampling_layer = MobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if dilation == 1 else 1, + dilation=dilation // 2 if dilation > 1 else 1, + ) + in_channels = out_channels + else: + self.downsampling_layer = None + + self.conv_kxk = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + ) + + self.conv_1x1 = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=1, + use_normalization=False, + use_activation=False, + ) + + self.transformer = MobileViTTransformer( + config, + hidden_size=hidden_size, + num_stages=num_stages, + ) + + self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + self.conv_projection = MobileViTConvLayer( + config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1 + ) + + self.fusion = MobileViTConvLayer( + config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size + ) + + def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size, channels, orig_height, orig_width = features.shape + + new_height = ( + torch_int(torch.ceil(orig_height / patch_height) * patch_height) + if torch.jit.is_tracing() + else int(math.ceil(orig_height / patch_height) * patch_height) + ) + new_width = ( + torch_int(torch.ceil(orig_width / patch_width) * patch_width) + if torch.jit.is_tracing() + else int(math.ceil(orig_width / patch_width) * patch_width) + ) + + interpolate = False + if new_width != orig_width or new_height != orig_height: + # Note: Padding can be done, but then it needs to be handled in attention function. + features = nn.functional.interpolate( + features, size=(new_height, new_width), mode="bilinear", align_corners=False + ) + interpolate = True + + # number of patches along width and height + num_patch_width = new_width // patch_width + num_patch_height = new_height // patch_height + num_patches = num_patch_height * num_patch_width + + # convert from shape (batch_size, channels, orig_height, orig_width) + # to the shape (batch_size * patch_area, num_patches, channels) + patches = features.reshape( + batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width + ) + patches = patches.transpose(1, 2) + patches = patches.reshape(batch_size, channels, num_patches, patch_area) + patches = patches.transpose(1, 3) + patches = patches.reshape(batch_size * patch_area, num_patches, -1) + + info_dict = { + "orig_size": (orig_height, orig_width), + "batch_size": batch_size, + "channels": channels, + "interpolate": interpolate, + "num_patches": num_patches, + "num_patches_width": num_patch_width, + "num_patches_height": num_patch_height, + } + return patches, info_dict + + def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size = info_dict["batch_size"] + channels = info_dict["channels"] + num_patches = info_dict["num_patches"] + num_patch_height = info_dict["num_patches_height"] + num_patch_width = info_dict["num_patches_width"] + + # convert from shape (batch_size * patch_area, num_patches, channels) + # back to shape (batch_size, channels, orig_height, orig_width) + features = patches.contiguous().view(batch_size, patch_area, num_patches, -1) + features = features.transpose(1, 3) + features = features.reshape( + batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width + ) + features = features.transpose(1, 2) + features = features.reshape( + batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width + ) + + if info_dict["interpolate"]: + features = nn.functional.interpolate( + features, size=info_dict["orig_size"], mode="bilinear", align_corners=False + ) + + return features + + def forward(self, features: torch.Tensor) -> torch.Tensor: + # reduce spatial dimensions if needed + if self.downsampling_layer: + features = self.downsampling_layer(features) + + residual = features + + # local representation + features = self.conv_kxk(features) + features = self.conv_1x1(features) + + # convert feature map to patches + patches, info_dict = self.unfolding(features) + + # learn global representations + patches = self.transformer(patches) + patches = self.layernorm(patches) + + # convert patches back to feature maps + features = self.folding(patches, info_dict) + + features = self.conv_projection(features) + features = self.fusion(torch.cat((residual, features), dim=1)) + return features + + +class MobileViTEncoder(nn.Module): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + self.config = config + + self.layer = nn.ModuleList() + self.gradient_checkpointing = False + + # segmentation architectures like DeepLab and PSPNet modify the strides + # of the classification backbones + dilate_layer_4 = dilate_layer_5 = False + if config.output_stride == 8: + dilate_layer_4 = True + dilate_layer_5 = True + elif config.output_stride == 16: + dilate_layer_5 = True + + dilation = 1 + + layer_1 = MobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[0], + out_channels=config.neck_hidden_sizes[1], + stride=1, + num_stages=1, + ) + self.layer.append(layer_1) + + layer_2 = MobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[1], + out_channels=config.neck_hidden_sizes[2], + stride=2, + num_stages=3, + ) + self.layer.append(layer_2) + + layer_3 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[2], + out_channels=config.neck_hidden_sizes[3], + stride=2, + hidden_size=config.hidden_sizes[0], + num_stages=2, + ) + self.layer.append(layer_3) + + if dilate_layer_4: + dilation *= 2 + + layer_4 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[3], + out_channels=config.neck_hidden_sizes[4], + stride=2, + hidden_size=config.hidden_sizes[1], + num_stages=4, + dilation=dilation, + ) + self.layer.append(layer_4) + + if dilate_layer_5: + dilation *= 2 + + layer_5 = MobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[4], + out_channels=config.neck_hidden_sizes[5], + stride=2, + hidden_size=config.hidden_sizes[2], + num_stages=3, + dilation=dilation, + ) + self.layer.append(layer_5) + + def forward( + self, + hidden_states: torch.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> tuple | BaseModelOutputWithNoAttention: + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +@auto_docstring +class MobileViTPreTrainedModel(PreTrainedModel): + config: MobileViTConfig + base_model_prefix = "mobilevit" + main_input_name = "pixel_values" + input_modalities = ("image",) + supports_gradient_checkpointing = True + _no_split_modules = ["MobileViTLayer"] + + @torch.no_grad() + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + init.zeros_(module.bias) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + elif isinstance(module, nn.LayerNorm): + init.zeros_(module.bias) + init.ones_(module.weight) + + +@auto_docstring +class MobileViTModel(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, expand_output: bool = True): + r""" + expand_output (`bool`, *optional*, defaults to `True`): + Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional + 1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`. + """ + super().__init__(config) + self.config = config + self.expand_output = expand_output + + self.conv_stem = MobileViTConvLayer( + config, + in_channels=config.num_channels, + out_channels=config.neck_hidden_sizes[0], + kernel_size=3, + stride=2, + ) + + self.encoder = MobileViTEncoder(config) + + if self.expand_output: + self.conv_1x1_exp = MobileViTConvLayer( + config, + in_channels=config.neck_hidden_sizes[5], + out_channels=config.neck_hidden_sizes[6], + kernel_size=1, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.conv_stem(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.expand_output: + last_hidden_state = self.conv_1x1_exp(encoder_outputs[0]) + + # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) + pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False) + else: + last_hidden_state = encoder_outputs[0] + pooled_output = None + + if not return_dict: + output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) + return output + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """ +) +class MobileViTForImageClassification(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevit = MobileViTModel(config) + + # Classifier head + self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True) + self.classifier = ( + nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + labels: torch.Tensor | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(self.dropout(pooled_output)) + + loss = None + if labels is not None: + loss = self.loss_function(labels, logits, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +class MobileViTASPPPooling(nn.Module): + def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None: + super().__init__() + + self.global_pool = nn.AdaptiveAvgPool2d(output_size=1) + + self.conv_1x1 = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_normalization=True, + use_activation="relu", + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + spatial_size = features.shape[-2:] + features = self.global_pool(features) + features = self.conv_1x1(features) + features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False) + return features + + +class MobileViTASPP(nn.Module): + """ + ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + + in_channels = config.neck_hidden_sizes[-2] + out_channels = config.aspp_out_channels + + if len(config.atrous_rates) != 3: + raise ValueError("Expected 3 values for atrous_rates") + + self.convs = nn.ModuleList() + + in_projection = MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + ) + self.convs.append(in_projection) + + self.convs.extend( + [ + MobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=rate, + use_activation="relu", + ) + for rate in config.atrous_rates + ] + ) + + pool_layer = MobileViTASPPPooling(config, in_channels, out_channels) + self.convs.append(pool_layer) + + self.project = MobileViTConvLayer( + config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu" + ) + + self.dropout = nn.Dropout(p=config.aspp_dropout_prob) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + pyramid = [] + for conv in self.convs: + pyramid.append(conv(features)) + pyramid = torch.cat(pyramid, dim=1) + + pooled_features = self.project(pyramid) + pooled_features = self.dropout(pooled_features) + return pooled_features + + +class MobileViTDeepLabV3(nn.Module): + """ + DeepLabv3 architecture: https://huggingface.co/papers/1706.05587 + """ + + def __init__(self, config: MobileViTConfig) -> None: + super().__init__() + self.aspp = MobileViTASPP(config) + + self.dropout = nn.Dropout2d(config.classifier_dropout_prob) + + self.classifier = MobileViTConvLayer( + config, + in_channels=config.aspp_out_channels, + out_channels=config.num_labels, + kernel_size=1, + use_normalization=False, + use_activation=False, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + features = self.aspp(hidden_states[-1]) + features = self.dropout(features) + features = self.classifier(features) + return features + + +@auto_docstring( + custom_intro=""" + MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC. + """ +) +class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.mobilevit = MobileViTModel(config, expand_output=False) + self.segmentation_head = MobileViTDeepLabV3(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + pixel_values: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | SemanticSegmenterOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Examples: + + ```python + >>> import httpx + >>> from io import BytesIO + >>> import torch + >>> from PIL import Image + >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + + >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small") + >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.mobilevit( + pixel_values, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.segmentation_head(encoder_hidden_states) + + loss = None + if labels is not None: + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + loss = loss_fct(upsampled_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +__all__ = [ + "MobileViTForImageClassification", + "MobileViTForSemanticSegmentation", + "MobileViTModel", + "MobileViTPreTrainedModel", +] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..82646d9f8927f0e848cb2ce4f9a4ecc6f810ab63 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py @@ -0,0 +1,279 @@ +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SpeechT5 model configuration""" + +import functools +import operator + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="microsoft/speecht5_asr") +@strict +class SpeechT5Config(PreTrainedConfig): + r""" + positional_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the text position encoding layers. + feat_extract_norm (`str`, *optional*, defaults to `"group"`): + The norm to be applied to 1D convolutional layers in the speech encoder pre-net. One of `"group"` for group + normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D + convolutional layers. + feat_proj_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for output of the speech encoder pre-net. + feat_extract_activation (`str, `optional`, defaults to `"gelu"`): + The non-linear activation function (function or string) in the 1D convolutional layers of the feature + extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. + conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + speech encoder pre-net. The length of *conv_dim* defines the number of 1D convolutional layers. + conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): + A tuple of integers defining the stride of each 1D convolutional layer in the speech encoder pre-net. The + length of *conv_stride* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the speech encoder pre-net. + The length of *conv_kernel* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether the 1D convolutional layers have a bias. + num_conv_pos_embeddings (`int`, *optional*, defaults to 128): + Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional + embeddings layer. + num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16): + Number of groups of 1D convolutional positional embeddings layer. + apply_spec_augment (`bool`, *optional*, defaults to `True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the speech encoder pre-net. For + reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://huggingface.co/papers/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If + reasoning from the probability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2),: + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over + the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0),: + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks'' + num_mel_bins (`int`, *optional*, defaults to 80): + Number of mel features used per input features. Used by the speech decoder pre-net. Should correspond to + the value used in the [`SpeechT5Processor`] class. + speech_decoder_prenet_layers (`int`, *optional*, defaults to 2): + Number of layers in the speech decoder pre-net. + speech_decoder_prenet_units (`int`, *optional*, defaults to 256): + Dimensionality of the layers in the speech decoder pre-net. + speech_decoder_prenet_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability for the speech decoder pre-net layers. + speaker_embedding_dim (`int`, *optional*, defaults to 512): + Dimensionality of the *XVector* embedding vectors. + speech_decoder_postnet_layers (`int`, *optional*, defaults to 5): + Number of layers in the speech decoder post-net. + speech_decoder_postnet_units (`int`, *optional*, defaults to 256): + Dimensionality of the layers in the speech decoder post-net. + speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5): + Number of convolutional filter channels in the speech decoder post-net. + speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5): + The dropout probability for the speech decoder post-net layers. + reduction_factor (`int`, *optional*, defaults to 2): + Spectrogram length reduction factor for the speech decoder inputs. + max_speech_positions (`int`, *optional*, defaults to 4000): + The maximum sequence length of speech features that this model might ever be used with. + max_text_positions (`int`, *optional*, defaults to 450): + The maximum sequence length of text features that this model might ever be used with. + encoder_max_relative_position (`int`, *optional*, defaults to 160): + Maximum distance for relative position embedding in the encoder. + use_guided_attention_loss (`bool`, *optional*, defaults to `True`): + Whether to apply guided attention loss while training the TTS model. + guided_attention_loss_num_heads (`int`, *optional*, defaults to 2): + Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all + attention heads. + guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4): + Standard deviation for guided attention loss. + guided_attention_loss_scale (`float`, *optional*, defaults to 10.0): + Scaling coefficient for guided attention loss (also known as lambda). + + Example: + + ```python + >>> from transformers import SpeechT5Model, SpeechT5Config + + >>> # Initializing a "microsoft/speecht5_asr" style configuration + >>> configuration = SpeechT5Config() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_asr" style configuration + >>> model = SpeechT5Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speecht5" + attribute_map = {"num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers"} + + vocab_size: int = 81 + hidden_size: int = 768 + encoder_layers: int = 12 + encoder_attention_heads: int = 12 + encoder_ffn_dim: int = 3072 + encoder_layerdrop: float | int = 0.1 + decoder_layers: int = 6 + decoder_ffn_dim: int = 3072 + decoder_attention_heads: int = 12 + decoder_layerdrop: float | int = 0.1 + hidden_act: str = "gelu" + positional_dropout: float | int = 0.1 + hidden_dropout: float | int = 0.1 + attention_dropout: float | int = 0.1 + activation_dropout: float | int = 0.1 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-5 + scale_embedding: bool = False + feat_extract_norm: str = "group" + feat_proj_dropout: float | int = 0.0 + feat_extract_activation: str = "gelu" + conv_dim: list[int] | tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512) + conv_stride: list[int] | tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2) + conv_kernel: list[int] | tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2) + conv_bias: bool = False + num_conv_pos_embeddings: int = 128 + num_conv_pos_embedding_groups: int = 16 + apply_spec_augment: bool = True + mask_time_prob: float | int = 0.05 + mask_time_length: int = 10 + mask_time_min_masks: int = 2 + mask_feature_prob: float | int = 0.0 + mask_feature_length: int = 10 + mask_feature_min_masks: int = 0 + pad_token_id: int | None = 1 + bos_token_id: int | None = 0 + eos_token_id: int | list[int] | None = 2 + decoder_start_token_id: int | None = 2 + num_mel_bins: int = 80 + speech_decoder_prenet_layers: int = 2 + speech_decoder_prenet_units: int = 256 + speech_decoder_prenet_dropout: float | int = 0.5 + speaker_embedding_dim: int = 512 + speech_decoder_postnet_layers: int = 5 + speech_decoder_postnet_units: int = 256 + speech_decoder_postnet_kernel: int = 5 + speech_decoder_postnet_dropout: float | int = 0.5 + reduction_factor: int = 2 + max_speech_positions: int = 4000 + max_text_positions: int = 450 + encoder_max_relative_position: int = 160 + use_guided_attention_loss: bool = True + guided_attention_loss_num_heads: int = 2 + guided_attention_loss_sigma: float = 0.4 + guided_attention_loss_scale: float = 10.0 + use_cache: bool = True + is_encoder_decoder: bool = True + tie_word_embeddings: bool = True + + def __post_init__(self, **kwargs): + self.num_feat_extract_layers = len(self.conv_dim) + super().__post_init__(**kwargs) + + def validate_architecture(self): + """Part of `@strict`-powered validation. Validates the architecture of the config.""" + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +@auto_docstring(checkpoint="microsoft/speecht5_asr") +@strict +class SpeechT5HifiGanConfig(PreTrainedConfig): + r""" + model_in_dim (`int`, *optional*, defaults to 80): + The number of frequency bins in the input log-mel spectrogram. + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the upsampling network. + upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[4, 4, 4, 4]`): + A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The + length of *upsample_rates* defines the number of convolutional layers and has to match the length of + *upsample_kernel_sizes*. + upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 8, 8]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The + length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of + *upsample_rates*. + resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field + fusion (MRF) module. + resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + multi-receptive field fusion (MRF) module. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation. + normalize_before (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance. + + Example: + + ```python + >>> from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig + + >>> # Initializing a "microsoft/speecht5_hifigan" style configuration + >>> configuration = SpeechT5HifiGanConfig() + + >>> # Initializing a model (with random weights) from the "microsoft/speecht5_hifigan" style configuration + >>> model = SpeechT5HifiGan(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speecht5_hifigan" + + model_in_dim: int = 80 + sampling_rate: int = 16000 + upsample_initial_channel: int = 512 + upsample_rates: list[int] | tuple[int, ...] = (4, 4, 4, 4) + upsample_kernel_sizes: list[int] | tuple[int, ...] = (8, 8, 8, 8) + resblock_kernel_sizes: list[int] | tuple[int, ...] = (3, 7, 11) + resblock_dilation_sizes: list | tuple = ((1, 3, 5), (1, 3, 5), (1, 3, 5)) + initializer_range: float = 0.01 + leaky_relu_slope: float = 0.1 + normalize_before: bool = True + + +__all__ = ["SpeechT5Config", "SpeechT5HifiGanConfig"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..f75fa9dcdcd9e9b9de5d7b5e2ceaeb2c9038de6f --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py @@ -0,0 +1,3095 @@ +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SpeechT5 model.""" + +import math + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...masking_utils import create_bidirectional_mask, create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSpectrogramOutput, +) +from ...modeling_utils import EmbeddingAccessMixin, PreTrainedModel +from ...utils import auto_docstring, logging +from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 1 + +# General docstring + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def shift_spectrograms_right( + input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: torch.Tensor | None = None +): + """ + Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length. + """ + # thin out frames for reduction factor + if reduction_factor > 1: + input_values = input_values[:, reduction_factor - 1 :: reduction_factor] + if attention_mask is not None: + attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor] + + shifted_input_values = input_values.new_zeros(input_values.shape) + shifted_input_values[:, 1:] = input_values[:, :-1].clone() + + # replace possible -100 values in labels by zeros + shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0) + + return shifted_input_values, attention_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: torch.LongTensor | None = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5 +class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5 +class SpeechT5SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None): + super().__init__() + self.offset = 2 + self.num_positions = num_positions + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int | None = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5 +class SpeechT5PositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class SpeechT5ScaledPositionalEncoding(nn.Module): + """ + Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895 + """ + + def __init__(self, dropout, dim, max_len=5000): + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim)) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) + super().__init__() + self.register_buffer("pe", pe, persistent=False) + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.max_len = max_len + self.alpha = nn.Parameter(torch.tensor(1.0)) + + def forward(self, emb): + emb = emb + self.alpha * self.pe[:, : emb.size(1)] + emb = self.dropout(emb) + return emb + + +class SpeechT5RelativePositionalEncoding(torch.nn.Module): + def __init__(self, dim, max_length=1000): + super().__init__() + self.dim = dim + self.max_length = max_length + self.pe_k = torch.nn.Embedding(2 * max_length, dim) + + def forward(self, hidden_states): + seq_len = hidden_states.shape[1] + pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long) + pos_seq = pos_seq[:, None] - pos_seq[None, :] + + pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq) + pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq) + pos_seq = pos_seq + self.max_length + + return self.pe_k(pos_seq) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5 +class SpeechT5SamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5 +class SpeechT5FeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [ + SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5 +class SpeechT5FeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +class SpeechT5SpeechEncoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.feature_encoder = SpeechT5FeatureEncoder(config) + self.feature_projection = SpeechT5FeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.pos_conv_embed = SpeechT5PositionalConvEmbedding(config) + self.pos_sinusoidal_embed = SpeechT5SinusoidalPositionalEmbedding( + config.max_speech_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def freeze_feature_encoder(self): + self.feature_encoder._freeze_parameters() + + def forward( + self, + input_values: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + mask_time_indices: torch.FloatTensor | None = None, + ): + extract_features = self.feature_encoder(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], + attention_mask, + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + positional_conv_embedding = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + positional_conv_embedding + + if attention_mask is not None: + padding_mask = attention_mask.ne(1).long() + else: + padding_mask = torch.zeros(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device) + + positional_sinusoidal_embeddings = self.pos_sinusoidal_embed(padding_mask) + hidden_states = hidden_states + positional_sinusoidal_embeddings + + return hidden_states, attention_mask + + # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://huggingface.co/papers/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + +class SpeechT5SpeechDecoderPrenet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList( + [ + nn.Linear( + config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units, + config.speech_decoder_prenet_units, + ) + for i in range(config.speech_decoder_prenet_layers) + ] + ) + + self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_speech_positions, + ) + self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size) + + def _consistent_dropout(self, inputs_embeds, p): + mask = torch.bernoulli(inputs_embeds[0], p=p) + all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1) + return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p) + + def forward( + self, + input_values: torch.Tensor, + speaker_embeddings: torch.Tensor | None = None, + ): + # Dropout is always applied, even when evaluating. See §2.2 in https://huggingface.co/papers/1712.05884. + + inputs_embeds = input_values + for layer in self.layers: + inputs_embeds = nn.functional.relu(layer(inputs_embeds)) + inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout) + + inputs_embeds = self.final_layer(inputs_embeds) + inputs_embeds = self.encode_positions(inputs_embeds) + + if speaker_embeddings is not None: + speaker_embeddings = nn.functional.normalize(speaker_embeddings) + speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1) + inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1) + inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds)) + + return inputs_embeds + + +class SpeechT5BatchNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + + if layer_id == 0: + in_conv_dim = config.num_mel_bins + else: + in_conv_dim = config.speech_decoder_postnet_units + + if layer_id == config.speech_decoder_postnet_layers - 1: + out_conv_dim = config.num_mel_bins + else: + out_conv_dim = config.speech_decoder_postnet_units + + self.conv = nn.Conv1d( + in_conv_dim, + out_conv_dim, + kernel_size=config.speech_decoder_postnet_kernel, + stride=1, + padding=(config.speech_decoder_postnet_kernel - 1) // 2, + bias=False, + ) + self.batch_norm = nn.BatchNorm1d(out_conv_dim) + + if layer_id < config.speech_decoder_postnet_layers - 1: + self.activation = nn.Tanh() + else: + self.activation = None + + self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + if self.activation is not None: + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class SpeechT5SpeechDecoderPostnet(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor) + self.prob_out = nn.Linear(config.hidden_size, config.reduction_factor) + + self.layers = nn.ModuleList( + [SpeechT5BatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)] + ) + + def forward(self, hidden_states: torch.Tensor): + outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins) + outputs_after_postnet = self.postnet(outputs_before_postnet) + logits = self.prob_out(hidden_states).view(hidden_states.size(0), -1) + return outputs_before_postnet, outputs_after_postnet, logits + + def postnet(self, hidden_states: torch.Tensor): + layer_output = hidden_states.transpose(1, 2) + for layer in self.layers: + layer_output = layer(layer_output) + return hidden_states + layer_output.transpose(1, 2) + + +class SpeechT5TextEncoderPrenet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.encode_positions = SpeechT5ScaledPositionalEncoding( + config.positional_dropout, + config.hidden_size, + config.max_text_positions, + ) + + def forward(self, input_ids: torch.Tensor): + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.encode_positions(inputs_embeds) + return inputs_embeds + + +class SpeechT5TextDecoderPrenet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.dropout = nn.Dropout(config.positional_dropout) + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + + self.embed_positions = SpeechT5SinusoidalPositionalEmbedding( + config.max_text_positions + config.pad_token_id + 1, + config.hidden_size, + config.pad_token_id, + ) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + ): + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + else: + raise ValueError("You have to specify `decoder_input_ids`") + + past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() + positions = self.embed_positions(input_ids, past_key_values_length) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds += positions + inputs_embeds = self.dropout(inputs_embeds) + + return inputs_embeds, attention_mask + + +class SpeechT5TextDecoderPostnet(nn.Module, EmbeddingAccessMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, hidden_states: torch.Tensor): + return self.lm_head(hidden_states) + + def get_output_embeddings(self): + # Post-net has no token embeddings, but its lm_head must still be + # tied to the decoder weights when `tie_word_embeddings=True`. + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + +class SpeechT5Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see + https://aclanthology.org/N18-2074.pdf) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float | None = 0.0, + is_decoder: bool | None = False, + bias: bool | None = True, + layer_idx: bool | None = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_values: Cache | None = None, + attention_mask: torch.Tensor | None = None, + position_bias: torch.Tensor | None = None, + output_attentions: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_values = past_key_values.cross_attention_cache + else: + curr_past_key_values = past_key_values.self_attention_cache + else: + curr_past_key_values = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_values.layers[self.layer_idx].keys + value_states = curr_past_key_values.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # relative attention bias + if position_bias is not None: + reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1) + rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1)) + rel_pos_bias = rel_pos_bias.transpose(0, 1).view( + bsz * self.num_heads, position_bias.size(0), position_bias.size(1) + ) + attn_weights += rel_pos_bias + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class SpeechT5FeedForward(nn.Module): + def __init__(self, config, intermediate_size): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SpeechT5EncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SpeechT5Config): + super().__init__() + self.attention = SpeechT5Attention( + embed_dim=config.hidden_size, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = SpeechT5FeedForward(config, config.encoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_bias: torch.Tensor | None = None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): + attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very + large negative values. + position_bias (`torch.FloatTensor`): + relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)` + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + ) + + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SpeechT5DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: SpeechT5Config, layer_idx=None): + super().__init__() + self.self_attn = SpeechT5Attention( + embed_dim=config.hidden_size, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.encoder_attn = SpeechT5Attention( + config.hidden_size, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = True, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = self.dropout(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class SpeechT5PreTrainedModel(PreTrainedModel): + config: SpeechT5Config + base_model_prefix = "speecht5" + main_input_name = "input_values" + input_modalities = "audio" + supports_gradient_checkpointing = True + + @torch.no_grad() + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, SpeechT5PositionalConvEmbedding): + init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + init.constant_(module.conv.bias, 0) + elif isinstance(module, SpeechT5ScaledPositionalEncoding): + init.ones_(module.alpha) + dim, max_len = module.dim, module.max_len + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim)) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) + init.copy_(module.pe, pe) + elif isinstance(module, SpeechT5FeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + init.uniform_(module.projection.weight, a=-k, b=k) + init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): + init.zeros_(module.bias) + init.ones_(module.weight) + if getattr(module, "running_mean", None) is not None: + init.zeros_(module.running_mean) + init.ones_(module.running_var) + init.zeros_(module.num_batches_tracked) + elif isinstance(module, nn.Conv1d): + init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + init.normal_(module.weight, mean=0.0, std=std) + # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag + if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False): + init.zeros_(module.weight[module.padding_idx]) + elif isinstance(module, SpeechT5SinusoidalPositionalEmbedding): + emb_weights = module.get_embedding( + module.num_positions + module.offset, module.embedding_dim, module.padding_idx + ) + init.copy_(module.weights, emb_weights) + elif isinstance(module, SpeechT5HifiGan): + init.zeros_(module.mean) + init.ones_(module.scale) + + if hasattr(module, "masked_spec_embed"): + init.uniform_(module.masked_spec_embed) + + +class SpeechT5Encoder(SpeechT5PreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layerdrop = config.encoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5EncoderLayer(config) for _ in range(config.encoder_layers)]) + + self.embed_positions = SpeechT5RelativePositionalEncoding( + config.hidden_size // config.encoder_attention_heads, config.encoder_max_relative_position + ) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the encoder prenet. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + ) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + position_bias = self.embed_positions(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to + hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + hidden_states, attention_mask = self.prenet(input_values, attention_mask) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextEncoderPrenet(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + hidden_states = self.prenet(input_values) + + outputs = self.wrapped_encoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with + [`SpeechT5Model`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_encoder = SpeechT5Encoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutput: + return self.wrapped_encoder( + hidden_states=input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SpeechT5Decoder(SpeechT5PreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SpeechT5DecoderLayer`] + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`): + Features extracted from the speech or text input by the decoder prenet. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + attention_mask = create_causal_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + ) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + skip_the_layer = False + if self.training: + dropout_probability = torch.rand([]) + skip_the_layer = dropout_probability < self.layerdrop + if skip_the_layer and not synced_gpus: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden + features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5SpeechDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.LongTensor | None = None, + speaker_embeddings: torch.Tensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: + decoder_hidden_states = self.prenet(input_values, speaker_embeddings) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel): + """ + Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.prenet = SpeechT5TextDecoderPrenet(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prenet.get_input_embeddings() + + def set_input_embeddings(self, value): + self.prenet.set_input_embeddings(value) + + def forward( + self, + input_values: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: + decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values) + + outputs = self.wrapped_decoder( + hidden_states=decoder_hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs + + +class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with + [`SpeechT5Model`]. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + self.wrapped_decoder = SpeechT5Decoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_values: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: + outputs = self.wrapped_decoder( + hidden_states=input_values, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs + + +class SpeechT5GuidedMultiheadAttentionLoss(nn.Module): + """ + Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention](https://huggingface.co/papers/1710.08969), adapted for multi-head attention. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.sigma = config.guided_attention_loss_sigma + self.scale = config.guided_attention_loss_scale + + def forward( + self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor + ) -> torch.Tensor: + """ + Compute the attention loss. + + Args: + attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`): + Batch of multi-head attention weights + input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`): + Input attention mask as booleans. + output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`): + Target attention mask as booleans. + + Returns: + `torch.Tensor` with the loss value + """ + guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device) + masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2) + masks = masks.to(attentions.device).unsqueeze(1) + + losses = guided_attn_masks * attentions + loss = torch.mean(losses.masked_select(masks)) + return self.scale * loss + + def _make_guided_attention_masks(self, input_masks, output_masks, device): + input_lengths = input_masks.sum(-1) + output_lengths = output_masks.sum(-1) + + guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device) + + for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device) + + return guided_attn_masks.unsqueeze(1) + + @staticmethod + def _make_guided_attention_mask(input_length, output_length, sigma, device): + grid_y, grid_x = torch.meshgrid( + torch.arange(input_length, device=device), + torch.arange(output_length, device=device), + indexing="xy", + ) + grid_x = grid_x.float() / output_length + grid_y = grid_y.float() / input_length + return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2))) + + +class SpeechT5SpectrogramLoss(nn.Module): + """ + Loss computation used by SpeechT5ForTextToSpeech. + """ + + def __init__(self, config: SpeechT5Config): + super().__init__() + self.use_guided_attention_loss = config.use_guided_attention_loss + self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads + self.reduction_factor = config.reduction_factor + + self.l1_criterion = L1Loss() + self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0)) + + if self.use_guided_attention_loss: + self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config) + + def forward( + self, + attention_mask: torch.LongTensor, + outputs_before_postnet: torch.FloatTensor, + outputs_after_postnet: torch.FloatTensor, + logits: torch.FloatTensor, + labels: torch.FloatTensor, + cross_attentions: torch.FloatTensor | None = None, + ) -> torch.Tensor: + padding_mask = labels != -100.0 + + # mask out the padded portions + labels = labels.masked_select(padding_mask) + outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask) + outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask) + + # spectrogram loss + l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels) + + # construct stop labels from the padding mask + masks = padding_mask[:, :, 0] + stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1) + stop_labels = stop_labels[:, 1:].masked_select(masks) + logits = logits.masked_select(masks) + + # stop token loss + bce_loss = self.bce_criterion(logits, stop_labels) + + # combined loss + loss = l1_loss + bce_loss + + # guided attention loss + if self.use_guided_attention_loss: + attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1) + input_masks = attention_mask == 1 + output_masks = padding_mask[:, :, 0] + if self.reduction_factor > 1: + output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor] + attn_loss = self.attn_criterion(attn, input_masks, output_masks) + loss += attn_loss + + return loss + + +@auto_docstring( + custom_intro=""" + The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets. + """ +) +class SpeechT5Model(SpeechT5PreTrainedModel): + def __init__( + self, + config: SpeechT5Config, + encoder: nn.Module | None = None, + decoder: nn.Module | None = None, + ): + r""" + encoder (`PreTrainedModel`, *optional*): + The encoder model to use. + decoder (`PreTrainedModel`, *optional*): + The decoder model to use. + """ + super().__init__(config) + self.config = config + self.encoder = SpeechT5EncoderWithoutPrenet(config) if encoder is None else encoder + self.decoder = SpeechT5DecoderWithoutPrenet(config) if decoder is None else decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + return self.encoder.get_input_embeddings() + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + return self.decoder.get_input_embeddings() + raise NotImplementedError + + def set_input_embeddings(self, value): + if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet): + self.encoder.set_input_embeddings(value) + if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet): + self.decoder.set_input_embeddings(value) + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + if isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + self.encoder.prenet.freeze_feature_encoder() + + @auto_docstring + def forward( + self, + input_values: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_values: torch.Tensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + speaker_embeddings: torch.FloatTensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs, + ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput: + r""" + input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): + Depending on which encoder is being used, the `input_values` are either: float values of the input raw + speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states. + decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel + filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in + the vocabulary, or hidden states. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_values=input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # downsample encoder attention mask (only for encoders with speech input) + if attention_mask is not None and isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = self.encoder.prenet._get_feature_vector_attention_mask( + encoder_outputs[0].shape[1], attention_mask + ) + else: + encoder_attention_mask = attention_mask + + if isinstance(self.decoder, SpeechT5DecoderWithSpeechPrenet): + decoder_args = {"speaker_embeddings": speaker_embeddings} + else: + decoder_args = {} + + decoder_outputs = self.decoder( + input_values=decoder_input_values, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **decoder_args, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a speech encoder and a text decoder. + """ +) +class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"text_decoder_postnet.lm_head.weight": "speecht5.decoder.prenet.embed_tokens.weight"} + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." + ) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + text_decoder = SpeechT5DecoderWithTextPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, text_decoder) + + self.text_decoder_postnet = SpeechT5TextDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.get_encoder().prenet.freeze_feature_encoder() + + def get_output_embeddings(self): + return self.text_decoder_postnet.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.text_decoder_postnet.set_output_embeddings(new_embeddings) + + @auto_docstring + def forward( + self, + input_values: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | Seq2SeqLMOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding + and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText + >>> from datasets import load_dataset + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation" + ... ) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr") + >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> predicted_ids = model.generate(**inputs, max_length=100) + + >>> # transcribe speech + >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + >>> transcription[0] + 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel' + ``` + + ```python + >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids + + >>> # compute loss + >>> loss = model(**inputs).loss + >>> round(loss.item(), 2) + 19.68 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + logits = self.text_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +def _generate_speech( + model: SpeechT5PreTrainedModel, + input_values: torch.FloatTensor, + speaker_embeddings: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: nn.Module | None = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, +) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]: + if speaker_embeddings is None: + raise ValueError( + """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following + the code snippet provided in this link: + https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors + """ + ) + + if attention_mask is None: + encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int() + else: + encoder_attention_mask = attention_mask + + bsz = input_values.size(0) + + encoder_out = model.speecht5.encoder( + input_values=input_values, + attention_mask=encoder_attention_mask, + return_dict=True, + ) + + encoder_last_hidden_state = encoder_out.last_hidden_state + + # downsample encoder attention mask + if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet): + encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask( + encoder_out[0].shape[1], encoder_attention_mask + ) + + maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor) + minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor) + + # Start the output sequence with a mel spectrum that is all zeros. + output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins) + + spectrogram = [] + cross_attentions = [] + past_key_values = None + idx = 0 + result_spectrogram = {} + + while True: + idx += 1 + + # Run the decoder prenet on the entire output sequence. + decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings) + # Run the decoder layers on the last element of the prenet output. + decoder_out = model.speecht5.decoder.wrapped_decoder( + hidden_states=decoder_hidden_states[:, -1:], + attention_mask=None, + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=True, + output_attentions=output_cross_attentions, + return_dict=True, + ) + + if output_cross_attentions: + cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0)) + + last_decoder_output = decoder_out.last_hidden_state.squeeze(1) + past_key_values = decoder_out.past_key_values + + # Predict the new mel spectrum for this step in the sequence. + spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output) + spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins) + spectrogram.append(spectrum) + + # Extend the output sequence with the new mel spectrum. + new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins) + output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1) + # Predict the probability that this is the stop token. + prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output)) + + if idx < minlen: + continue + else: + # If the generation loop is less than maximum length time, check the ones in the batch that have met + # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch. + if idx < maxlen: + meet_thresholds = torch.sum(prob, dim=-1) >= threshold + meet_indexes = torch.where(meet_thresholds)[0].tolist() + else: + meet_indexes = range(len(prob)) + meet_indexes = [i for i in meet_indexes if i not in result_spectrogram] + if len(meet_indexes) > 0: + spectrograms = torch.stack(spectrogram) + spectrograms = spectrograms.transpose(0, 1).flatten(1, 2) + spectrograms = model.speech_decoder_postnet.postnet(spectrograms) + for meet_index in meet_indexes: + result_spectrogram[meet_index] = spectrograms[meet_index] + if len(result_spectrogram) >= bsz: + break + spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))] + if not return_output_lengths: + spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + if vocoder is not None: + outputs = vocoder(spectrogram) + else: + outputs = spectrogram + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + if bsz > 1: + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (outputs, cross_attentions) + else: + # batched return values should also include the spectrogram/waveform lengths + spectrogram_lengths = [] + for i in range(bsz): + spectrogram_lengths.append(spectrograms[i].size(0)) + if vocoder is None: + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + outputs = (spectrograms, spectrogram_lengths) + else: + waveforms = [] + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + waveforms = vocoder(spectrograms) + waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths] + outputs = (waveforms, waveform_lengths) + if output_cross_attentions: + cross_attentions = torch.cat(cross_attentions, dim=2) + cross_attentions = cross_attentions.view( + bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:] + ) + outputs = (*outputs, cross_attentions) + return outputs + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a text encoder and a speech decoder. + """ +) +class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel): + input_modalities = ("text",) + main_input_name = "input_ids" + + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that does not define the" + " vocabulary size of the language model head. Please instantiate the model as follows:" + " `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of" + " your model's configuration." + ) + + text_encoder = SpeechT5EncoderWithTextPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, text_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def can_generate(cls) -> bool: + # Speecht5 has a unique model structure, where the external class (`SpeechT5ForTextToSpeech`) doesn't need to inherit from + # `GenerationMixin` (it has a non-standard generation method). This means that the base `can_generate()` will return `False`, + # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase. + return True + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_values: torch.FloatTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + speaker_embeddings: torch.FloatTensor | None = None, + labels: torch.FloatTensor | None = None, + stop_labels: torch.Tensor | None = None, + **kwargs, + ) -> tuple | Seq2SeqSpectrogramOutput: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss + computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`] + for details. + stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Binary tensor indicating the position of the stop token in the sequence. + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed + >>> import torch + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") + >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt") + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([15872]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + if self.config.use_guided_attention_loss: + output_attentions = True + + outputs = self.speecht5( + input_values=input_ids, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + if labels is not None: + criterion = SpeechT5SpectrogramLoss(self.config) + loss = criterion( + attention_mask, + outputs_before_postnet, + outputs_after_postnet, + logits, + labels, + outputs.cross_attentions, + ) + + if not return_dict: + output = (outputs_after_postnet,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=outputs_after_postnet, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor | None = None, + speaker_embeddings: torch.FloatTensor | None = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: nn.Module | None = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + **kwargs, + ) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Attention mask from the tokenizer, required for batched inference to signal to the model where to + ignore padded tokens from the input_ids. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch_size." + ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + @torch.no_grad() + def generate_speech( + self, + input_ids: torch.LongTensor, + speaker_embeddings: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: nn.Module | None = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]: + r""" + Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a + speech waveform using a vocoder. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and + [`~PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is not None: + batch_size = input_ids.size(0) + if speaker_embeddings.size(0) != batch_size: + if speaker_embeddings.size(0) == 1: + speaker_embeddings = speaker_embeddings.repeat(batch_size, 1) + else: + raise ValueError( + "The first dimension of speaker_embeddings must be either 1 or the same as batch size." + ) + + return _generate_speech( + self, + input_ids, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +@auto_docstring( + custom_intro=""" + SpeechT5 Model with a speech encoder and a speech decoder. + """ +) +class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): + def __init__(self, config: SpeechT5Config): + super().__init__(config) + + speech_encoder = SpeechT5EncoderWithSpeechPrenet(config) + speech_decoder = SpeechT5DecoderWithSpeechPrenet(config) + self.speecht5 = SpeechT5Model(config, speech_encoder, speech_decoder) + + self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config) + + # Initialize weights and apply final processing + self.post_init() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.get_encoder().prenet.freeze_feature_encoder() + + @auto_docstring + def forward( + self, + input_values: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + decoder_input_values: torch.FloatTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + speaker_embeddings: torch.FloatTensor | None = None, + labels: torch.FloatTensor | None = None, + stop_labels: torch.Tensor | None = None, + **kwargs, + ) -> tuple | Seq2SeqSpectrogramOutput: + r""" + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file + into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library + (`pip install torchcodec`) or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into + a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`): + Float values of input mel spectrogram. + + SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If + `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more + information on the default strategy. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*): + Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See + [`SpeechT5Processor.__call__`] for details. + stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Binary tensor indicating the position of the stop token in the sequence. + + Example: + + ```python + >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset( + ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation" + ... ) # doctest: +IGNORE_RESULT + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc") + >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc") + >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") + + >>> # audio file is decoded on the fly + >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file + + >>> set_seed(555) # make deterministic + + >>> # generate speech + >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder) + >>> speech.shape + torch.Size([77824]) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if labels is not None: + if decoder_input_values is None: + decoder_input_values, decoder_attention_mask = shift_spectrograms_right( + labels, self.config.reduction_factor, decoder_attention_mask + ) + + outputs = self.speecht5( + input_values=input_values, + attention_mask=attention_mask, + decoder_input_values=decoder_input_values, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + speaker_embeddings=speaker_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) + + loss = None + + if not return_dict: + output = (spectrogram,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSpectrogramOutput( + loss=loss, + spectrogram=spectrogram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @torch.no_grad() + def generate_speech( + self, + input_values: torch.FloatTensor, + speaker_embeddings: torch.FloatTensor | None = None, + attention_mask: torch.LongTensor | None = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 20.0, + vocoder: nn.Module | None = None, + output_cross_attentions: bool = False, + return_output_lengths: bool = False, + ) -> torch.FloatTensor: + r""" + Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a + speech waveform using a vocoder. + + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. + + Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`, + a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) + or the soundfile library (`pip install soundfile`). + To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details. + speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*): + Tensor containing the speaker embeddings. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + threshold (`float`, *optional*, defaults to 0.5): + The generated sequence ends when the predicted stop token probability exceeds this value. + minlenratio (`float`, *optional*, defaults to 0.0): + Used to calculate the minimum required length for the output sequence. + maxlenratio (`float`, *optional*, defaults to 20.0): + Used to calculate the maximum allowed length for the output sequence. + vocoder (`nn.Module`, *optional*, defaults to `None`): + The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel + spectrogram. + output_cross_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of the decoder's cross-attention layers. + return_output_lengths (`bool`, *optional*, defaults to `False`): + Whether or not to return the concrete spectrogram/waveform lengths. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements depending on the inputs: + - when `return_output_lengths` is False + - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram. + - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(num_frames,)` -- The predicted speech waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + - when `return_output_lengths` is True + - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that + are padded to the maximum length. + - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of + all the concrete lengths for each spectrogram. + - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape + `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length. + - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all + the concrete lengths for each waveform. + - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`) + `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads, + output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers. + """ + if speaker_embeddings is None: + speaker_embeddings = torch.zeros((1, 512), device=input_values.device) + + return _generate_speech( + self, + input_values, + speaker_embeddings, + attention_mask, + threshold, + minlenratio, + maxlenratio, + vocoder, + output_cross_attentions, + return_output_lengths, + ) + + +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +@auto_docstring( + custom_intro=""" + HiFi-GAN vocoder. + """ +) +class SpeechT5HifiGan(PreTrainedModel): + config: SpeechT5HifiGanConfig + main_input_name = "spectrogram" + + def __init__(self, config: SpeechT5HifiGanConfig): + super().__init__(config) + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + config.model_in_dim, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3) + + self.register_buffer("mean", torch.zeros(config.model_in_dim)) + self.register_buffer("scale", torch.ones(config.model_in_dim)) + + # Initialize weights and apply final processing + self.post_init() + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, SpeechT5HifiGan): + init.zeros_(module.mean) + init.ones_(module.scale) + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) + for layer in self.upsampler: + weight_norm(layer) + for layer in self.resblocks: + layer.apply_weight_norm() + weight_norm(self.conv_post) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + for layer in self.upsampler: + nn.utils.remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + nn.utils.remove_weight_norm(self.conv_post) + + @auto_docstring( + custom_intro=""" + Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch + of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech + waveform. + """ + ) + def forward(self, spectrogram: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + r""" + spectrogram (`torch.FloatTensor`): + Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length, + config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`. + + Returns: + `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of + shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`. + """ + if self.config.normalize_before: + spectrogram = (spectrogram - self.mean) / self.scale + + is_batched = spectrogram.dim() == 3 + if not is_batched: + spectrogram = spectrogram.unsqueeze(0) + + hidden_states = spectrogram.transpose(2, 1) + + hidden_states = self.conv_pre(hidden_states) + for i in range(self.num_upsamples): + hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) + hidden_states = self.upsampler[i](hidden_states) + + res_state = self.resblocks[i * self.num_kernels](hidden_states) + for j in range(1, self.num_kernels): + res_state += self.resblocks[i * self.num_kernels + j](hidden_states) + hidden_states = res_state / self.num_kernels + + hidden_states = nn.functional.leaky_relu(hidden_states) + hidden_states = self.conv_post(hidden_states) + hidden_states = torch.tanh(hidden_states) + + if not is_batched: + # remove batch dim and collapse tensor to 1-d audio waveform + waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1) + else: + # remove seq-len dim since this collapses to 1 + waveform = hidden_states.squeeze(1) + + return waveform + + +__all__ = [ + "SpeechT5ForSpeechToText", + "SpeechT5ForSpeechToSpeech", + "SpeechT5ForTextToSpeech", + "SpeechT5Model", + "SpeechT5PreTrainedModel", + "SpeechT5HifiGan", +] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..77da92fa608b183d14b86416507d69e31e57890b --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py @@ -0,0 +1,191 @@ +# Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Number Normalizer class for SpeechT5.""" + +import re + + +class EnglishNumberNormalizer: + def __init__(self): + self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] + self.teens = [ + "", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ] + self.tens = ["", "ten", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + self.thousands = [ + "", + "thousand", + "million", + "billion", + "trillion", + "quadrillion", + "quintillion", + "sextillion", + "septillion", + "octillion", + "nonillion", + "decillion", + ] + + # Define a dictionary to map currency symbols to their names + # Top most traded currencies according to + # https://en.wikipedia.org/wiki/Template:Most_traded_currencies + self.currency_symbols = { + "$": " dollars", + "€": " euros", + "£": " pounds", + "¢": " cents", + "¥": " japanese yen", + "﷼": " saudi riyal", + "₹": " indian rupees", + "₽": " russian rubles", + "฿": " thai baht", + "₺": " turkish liras", + "₴": " ukrainian hryvnia", + "₣": " swiss francs", + "₡": " costa rican colon", + "₱": " philippine peso", + "₪": " israeli shekels", + "₮": " mongolian tögrög", + "₩": " south korean won", + "₦": " nigerian naira", + "₫": " vietnamese Đồng", + } + + def spell_number(self, num): + if num == 0: + return "zero" + + parts = [] + for i in range(0, len(self.thousands)): + if num % 1000 != 0: + part = "" + hundreds = num % 1000 // 100 + tens_units = num % 100 + + if hundreds > 0: + part += self.ones[hundreds] + " hundred" + if tens_units > 0: + part += " and " + + if tens_units > 10 and tens_units < 20: + part += self.teens[tens_units - 10] + else: + tens_digit = self.tens[tens_units // 10] + ones_digit = self.ones[tens_units % 10] + if tens_digit: + part += tens_digit + if ones_digit: + if tens_digit: + part += " " + part += ones_digit + + parts.append(part) + + num //= 1000 + + return " ".join(reversed(parts)) + + def convert(self, number): + """ + Converts an individual number passed in string form to spelt-out form + """ + if "." in number: + integer_part, decimal_part = number.split(".") + else: + integer_part, decimal_part = number, "00" + + # Extract currency symbol if present + currency_symbol = "" + for symbol, name in self.currency_symbols.items(): + if integer_part.startswith(symbol): + currency_symbol = name + integer_part = integer_part[len(symbol) :] + break + + if integer_part.startswith("-"): + if integer_part[1:].startswith(symbol): + currency_symbol = name + integer_part = "-" + integer_part[len(symbol) + 1 :] + break + + # Extract 'minus' prefix for negative numbers + minus_prefix = "" + if integer_part.startswith("-"): + minus_prefix = "minus " + integer_part = integer_part[1:] + elif integer_part.startswith("minus"): + minus_prefix = "minus " + integer_part = integer_part[len("minus") :] + + percent_suffix = "" + if "%" in integer_part or "%" in decimal_part: + percent_suffix = " percent" + integer_part = integer_part.replace("%", "") + decimal_part = decimal_part.replace("%", "") + + integer_part = integer_part.zfill(3 * ((len(integer_part) - 1) // 3 + 1)) + + parts = [] + for i in range(0, len(integer_part), 3): + chunk = int(integer_part[i : i + 3]) + if chunk > 0: + part = self.spell_number(chunk) + unit = self.thousands[len(integer_part[i:]) // 3 - 1] + if unit: + part += " " + unit + parts.append(part) + + spelled_integer = " ".join(parts) + + # Format the spelt-out number based on conditions, such as: + # If it has decimal parts, currency symbol, minus prefix, etc + if decimal_part == "00": + return ( + f"{minus_prefix}{spelled_integer}{percent_suffix}{currency_symbol}" + if minus_prefix or currency_symbol + else f"{spelled_integer}{percent_suffix}" + ) + else: + spelled_decimal = " ".join([self.spell_number(int(digit)) for digit in decimal_part]) + return ( + f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}{currency_symbol}" + if minus_prefix or currency_symbol + else f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}" + ) + + def __call__(self, text): + """ + Convert numbers / number-like quantities in a string to their spelt-out counterparts + """ + # Form part of the pattern for all currency symbols + pattern = r"(? 15000, etc) + text = re.sub(r"(\d+,\d+)", lambda match: match.group(1).replace(",", ""), text) + + # Use regex to find and replace numbers in the text + converted_text = re.sub(pattern, lambda match: self.convert(match.group(1)), text) + converted_text = re.sub(" +", " ", converted_text) + + return converted_text diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py new file mode 100644 index 0000000000000000000000000000000000000000..2b39b6180af20a0bf0672ee88dc88fb7bd670e25 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py @@ -0,0 +1,166 @@ +# Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for SpeechT5.""" + +from typing import Any + +from ...tokenization_utils_sentencepiece import SentencePieceBackend +from ...utils import logging +from ...utils.import_utils import requires +from .number_normalizer import EnglishNumberNormalizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"} + + +@requires(backends=("sentencepiece",)) +class SpeechT5Tokenizer(SentencePieceBackend): + """ + Construct a SpeechT5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + bos_token (`str`, *optional*, defaults to `""`): + The begin of sequence token. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + normalize (`bool`, *optional*, defaults to `False`): + Whether to convert numeric quantities in the text to their spelt-out english counterparts. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + is_fast = False + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + unk_token="", + pad_token="", + normalize=False, + sp_model_kwargs: dict[str, Any] | None = None, + **kwargs, + ) -> None: + self.normalize = normalize + self._normalizer = None + + # Prepare sp_model_kwargs for parent class + if sp_model_kwargs is not None: + kwargs["sp_model_kwargs"] = sp_model_kwargs + + # Call parent init (which will load sp_model) + super().__init__( + vocab_file=vocab_file, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + normalize=normalize, + **kwargs, + ) + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + normalize = kwargs.pop("normalize", self.normalize) + if is_split_into_words: + text = " " + text + if normalize: + text = self.normalizer(text) + return (text, kwargs) + + @property + def normalizer(self): + if self._normalizer is None: + self._normalizer = EnglishNumberNormalizer() + return self._normalizer + + @normalizer.setter + def normalizer(self, value): + self._normalizer = value + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]: + """Build model inputs from a sequence by appending eos_token_id.""" + if token_ids_1 is None: + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False + ) -> list[int]: + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + suffix_ones = [1] + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + suffix_ones + return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: list[int] | None = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. SpeechT5 does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + eos = [self.eos_token_id] + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + token_ids_1 + eos) * [0] + + +__all__ = ["SpeechT5Tokenizer"] diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt new file mode 100644 index 0000000000000000000000000000000000000000..cee59596f906b46a4ab97fd91649d6f34bd0b6d1 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3a0ce15a3f8e0441fca84965ff658de402bc494bd94b53730430287ab2ab2df +size 927700322 diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt new file mode 100644 index 0000000000000000000000000000000000000000..8bf2acb44b036e0c0cd12383ca4628a26c263954 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c731b03de97e220aab47f37b3d6c191aa23591d0c4f973a4b38e93730358b2bf +size 927700322 diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt new file mode 100644 index 0000000000000000000000000000000000000000..a3b7d208aa0b2583164317999fe7e841dc2f4086 --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa1869d751dd1db17309c432203e4d0978a22ab1fbe9065646c5d04cfe9baa67 +size 927700322 diff --git a/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt new file mode 100644 index 0000000000000000000000000000000000000000..57dd7ae6e2742a98aeae7e53dbd6c0854d01a78d --- /dev/null +++ b/LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ca8affe97a9e4ab92c98e52693f27b329e99dd9122eb9b7672ab56618aaf840 +size 927700322