Spaces:
Runtime error
Runtime error
| # By Forge | |
| import torch | |
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
    """Split each byte of ``x`` into two signed 4-bit values packed in int32.

    ``x`` is reinterpreted as bytes and laid out as ``(x.size(0), -1)``.
    Every byte yields its low nibble followed by its high nibble; 8 is
    subtracted from each nibble (as int8), and the resulting int8 buffer
    is reinterpreted as int32, so each int32 holds four unpacked values.

    Returns:
        An int32 tensor whose first dimension is ``x.size(0)``.
    """
    row_count = x.size(0)
    raw_bytes = x.view(torch.uint8).view(row_count, -1)

    low_nibbles = raw_bytes & 15
    high_nibbles = raw_bytes >> 4

    # Interleave (low, high) per byte, then flatten back to one row per entry.
    nibbles = torch.stack([low_nibbles, high_nibbles], dim=-1).view(row_count, -1)

    # Shift the 0..15 range down to -8..7 while still in 8-bit storage.
    signed_nibbles = nibbles.view(torch.int8) - 8
    return signed_nibbles.view(torch.int32)
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
    """Split each byte of ``x`` into two unsigned 4-bit values packed in int32.

    ``x`` is reinterpreted as bytes and laid out as ``(x.size(0), -1)``.
    Every byte yields its low nibble followed by its high nibble, each
    kept in the 0..15 range, and the resulting byte buffer is
    reinterpreted as int32, so each int32 holds four unpacked values.

    Returns:
        An int32 tensor whose first dimension is ``x.size(0)``.
    """
    row_count = x.size(0)
    raw_bytes = x.view(torch.uint8).view(row_count, -1)

    low_nibbles = raw_bytes & 15
    high_nibbles = raw_bytes >> 4

    # Interleave (low, high) per byte, then flatten back to one row per entry.
    nibbles = torch.stack([low_nibbles, high_nibbles], dim=-1).view(row_count, -1)
    return nibbles.view(torch.int32)
# The lookup-table fast path relies on torch.uint16; when that dtype is
# missing the optimizations are switched off and a notice is printed.
disable_all_optimizations = not hasattr(torch, 'uint16')

if disable_all_optimizations:
    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
else:
    # One entry per possible 16-bit value (0 .. 65535): the four nibbles of
    # that value, already unpacked and stored together in a single int32.
    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(
        torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16)
    )[:, 0]
    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(
        torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16)
    )[:, 0]
def quick_unpack_4bits(x):
    """Unpack every byte of ``x`` into two signed 4-bit values (int8, -8..7).

    When optimizations are enabled, pairs of bytes are reinterpreted as
    uint16 and used as indices into the precomputed module-level table
    ``native_4bits_lookup_table``; the table is moved to ``x``'s device
    (and the global rebound) if it lives elsewhere. Otherwise the nibbles
    are extracted directly with bit operations.

    NOTE(review): presumably ``x`` is a uint8 tensor whose layout permits a
    uint16 view on the fast path -- confirm against callers.

    Returns:
        An int8 tensor of shape ``(x.size(0), -1)``; for each input byte
        the low nibble comes first, then the high nibble.
    """
    global native_4bits_lookup_table

    if not disable_all_optimizations:
        row_count = x.size(0)
        packed_pairs = x.view(torch.uint16)

        # Keep the table on the same device as the data being unpacked.
        if packed_pairs.device != native_4bits_lookup_table.device:
            native_4bits_lookup_table = native_4bits_lookup_table.to(device=packed_pairs.device)

        flat_indices = packed_pairs.to(dtype=torch.int32).flatten()
        gathered = torch.index_select(input=native_4bits_lookup_table, dim=0, index=flat_indices)

        # Each int32 table entry carries four int8 values.
        return gathered.view(torch.int8).view(row_count, -1)

    nibbles = torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1)
    return nibbles.view(torch.int8) - 8
def quick_unpack_4bits_u(x):
    """Unpack every byte of ``x`` into two unsigned 4-bit values (0..15).

    When optimizations are enabled, pairs of bytes are reinterpreted as
    uint16 and used as indices into the precomputed module-level table
    ``native_4bits_lookup_table_u``; the table is moved to ``x``'s device
    (and the global rebound) if it lives elsewhere. Otherwise the nibbles
    are extracted directly with bit operations.

    NOTE(review): presumably ``x`` is a uint8 tensor whose layout permits a
    uint16 view on the fast path -- confirm against callers.

    Returns:
        A tensor of shape ``(x.size(0), -1)``; for each input byte the low
        nibble comes first, then the high nibble. The fast path returns
        uint8; the fallback keeps the dtype of ``x``.
    """
    global native_4bits_lookup_table_u

    if not disable_all_optimizations:
        row_count = x.size(0)
        packed_pairs = x.view(torch.uint16)

        # Keep the table on the same device as the data being unpacked.
        if packed_pairs.device != native_4bits_lookup_table_u.device:
            native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=packed_pairs.device)

        flat_indices = packed_pairs.to(dtype=torch.int32).flatten()
        gathered = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=flat_indices)

        # Each int32 table entry carries four uint8 values.
        return gathered.view(torch.uint8).view(row_count, -1)

    low_then_high = torch.stack([x & 15, x >> 4], dim=-1)
    return low_then_high.view(x.size(0), -1)
def change_4bits_order(x):
    """Repack the 4-bit halves of each row of ``x`` in a different order.

    Each row is split into all of its low nibbles followed by all of its
    high nibbles (in original byte order). Consecutive pairs of that
    sequence are then packed back into bytes, the first of each pair in
    the low 4 bits and the second in the high 4 bits.

    NOTE(review): presumably ``x`` is a 2-D uint8 tensor -- confirm against
    callers.

    Returns:
        A tensor with the same number of rows as ``x`` holding the
        repacked bytes.
    """
    row_count = x.size(0)
    low_nibbles = x & 15
    high_nibbles = x >> 4

    # dim=-2 places the whole low-nibble run before the high-nibble run.
    grouped = torch.stack([low_nibbles, high_nibbles], dim=-2).view(row_count, -1)

    low_half = grouped[:, ::2]
    high_half = grouped[:, 1::2]
    return low_half | (high_half << 4)