| import math | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from scipy.optimize import minimize_scalar | |
# Trade fp32-matmul precision for speed (TF32 on supported GPUs); the
# codebooks here are fp16 anyway, so this is at most a minor speedup.
torch.set_float32_matmul_precision('high')
torch.manual_seed(0)  # deterministic Gaussian sampling across runs
def opt_err_cvx(fn):
    """Minimize the scalar objective `fn` over scales in [0.1, 100].

    Parameters:
        fn: callable taking a scalar scale and returning a comparable
            error value (float or 0-dim tensor).

    Returns:
        (err, scale): the minimal objective value as returned by `fn`,
        and the minimizing scale as a Python float.
    """
    # method='bounded' makes `bounds` take effect on every scipy version;
    # older scipy silently ignored `bounds` under the default Brent method.
    res = minimize_scalar(fn, bounds=(0.1, 100), method='bounded')
    # float() works whether res.x is a NumPy scalar or a plain Python float
    # (the original `.item()` raises AttributeError on a plain float).
    scale = float(res.x)
    err = res.fun
    return err, scale
def round(X, grid, grid_norm):
    """Snap each row of X to its nearest codeword in `grid`.

    Uses argmin_g ||x - g||^2 == argmax_g (2 x.g - ||g||^2), so only one
    matmul is needed and pairwise distances are never materialized.
    `grid_norm` must be the precomputed squared norms of the grid rows.

    NOTE: shadows the builtin `round`; kept for compatibility with callers.
    """
    scores = 2.0 * X.matmul(grid.T) - grid_norm
    nearest_idx = scores.argmax(dim=-1)
    return grid[nearest_idx]
def get_hint_curve(bit_cap=4, cols=1):
    """Sweep half-integer product codebooks of growing radius and record the
    best achievable rounding MSE at each bitrate up to `bit_cap`.

    Returns data = [bits_list, err_list, scale_list].
    NOTE(review): requires a CUDA device (`.cuda()` calls below).
    """
    def round_mvn(grid, dim=cols, nsamples=50000, sample_bs=500):
        # MSE of rounding |N(0, I_dim)| samples onto `grid` at the best scale.
        # abs(): the half-integer grid only covers the positive orthant; the
        # sign is presumably stored separately (the "+ 1" bit below) — TODO confirm.
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device).abs()
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            # Accumulate over fixed-size batches to bound peak GPU memory.
            total_err = 0
            for i in range(nsamples//sample_bs):
                sample_b = X[i*sample_bs: (i+1)*sample_bs].cuda()
                total_err += (round(sample_b*s, grid, grid_norm)/s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err/(nsamples//sample_bs)
            return total_err.cpu()
        return opt_err_cvx(test_s)
    bits = 0
    last_bits = 0
    cr = 1  # current outer radius of the grid being enumerated
    data = [[], [], []]
    while bits < bit_cap:
        # Half-integer points {1/2, 3/2, ..., cr - 1/2}^cols.
        base_grid = torch.arange(0, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1/2] * cols)
        if cols == 1:
            grid = grid.unsqueeze(-1)  # cartesian_prod of one tensor is 1-D
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only norms in the new shell [(cr-1)^2, cr^2): inner shells were
        # already handled on previous iterations of the while loop.
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample shells to keep the sweep cheap
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            # log2(|cb|)/cols bits per weight, +1 presumably for the sign bit.
            bits = math.log(len(cb))/math.log(2)/cols + 1
            if bits - last_bits < 0.1:  # skip near-duplicate bitrates
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def get_D4_curve(bit_cap=4):
    """Sweep 4-dim codebooks of half-integer points with even coordinate sum
    (a shifted D4 lattice) and record the best rounding MSE per bitrate.

    Returns data = [bits_list, err_list, scale_list].
    NOTE(review): requires a CUDA device.
    """
    def round_mvn(grid, nsamples=50000, sample_bs=1000):
        # NOTE(review): unlike the other curves, this evaluates all samples in
        # one shot; `sample_bs` is accepted but unused.
        dim = grid.shape[-1]
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype).to(grid.device)
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            err = (round(X*s, grid, grid_norm)/s - X).float().norm()**2 / torch.numel(X)
            return err.cpu()
        return opt_err_cvx(test_s)
    _D4_CODESZ = 4
    bits = 0
    last_bits = 0
    cr = 1  # current outer radius of the grid being enumerated
    data = [[], [], []]
    while bits < bit_cap:
        # Half-integer points in [-cr + 1/2, cr - 1/2]^4 ...
        base_grid = torch.arange(-cr, cr).to(torch.float16)
        grid = torch.cartesian_prod(*[base_grid + 1/2] * _D4_CODESZ)
        # ... restricted to even coordinate sums (the D-lattice constraint).
        grid = grid[torch.where(grid.sum(dim=-1) % 2 == 0)[0]]
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only norms in the new shell [(cr-1)^2, cr^2); inner shells were
        # handled on previous iterations.
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample shells to keep the sweep cheap
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            bits = math.log(len(cb))/math.log(2)/_D4_CODESZ  # bits per weight
            if bits - last_bits < 0.1:  # skip near-duplicate bitrates
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def get_E8_curve(bit_cap=4):
    """Sweep E8-lattice codebooks of growing radius and record the best
    rounding MSE per bitrate up to `bit_cap`.

    Returns data = [bits_list, err_list, scale_list].
    NOTE(review): requires a CUDA device.
    """
    def round_mvn(grid, nsamples=50000, sample_bs=250):
        dim = grid.shape[-1]
        X = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim)).rsample([nsamples]).to(grid.dtype)
        # Canonicalize each sample under the even-sign-flip symmetry so the
        # reduced codebook (cb_part) suffices: take |x|, then if an odd number
        # of coordinates were negative, flip coordinate 0 back negative.
        X_part = torch.abs(X)
        X_odd = torch.where((X < 0).sum(dim=-1) % 2 != 0)[0]
        X_part[X_odd, 0] = -X_part[X_odd, 0]
        X = X_part
        grid_norm = grid.norm(dim=-1)**2
        def test_s(s):
            # Accumulate over fixed-size batches to bound peak GPU memory.
            total_err = 0
            for i in range(nsamples//sample_bs):
                sample_b = X[i*sample_bs: (i+1)*sample_bs].cuda()
                total_err += (round(sample_b*s, grid, grid_norm)/s - sample_b).float().norm()**2 / torch.numel(sample_b)
            total_err = total_err/(nsamples//sample_bs)
            return total_err.cpu()
        return opt_err_cvx(test_s)
    def flip_cb(cb, flips, batch_size=5000000):
        # Expand `cb` by every sign pattern in `flips` (bit 0 -> +1, 1 -> -1),
        # batched over cb rows to bound memory.
        map = 1 - 2*flips
        output = torch.zeros((len(cb), len(map), cb.shape[-1]), dtype=cb.dtype, device='cpu')
        map = map.unsqueeze(0)
        for i in range(math.ceil(len(cb)/batch_size)):
            next = min(len(cb), (i+1)*batch_size)
            output[i*batch_size: next] = (cb[i*batch_size:next].unsqueeze(1)*map).cpu()
        return output.reshape(-1, cb.shape[-1])
    def batched_unique(cpu_tensor, batch_size=10**9):
        # Row-wise unique computed chunk-by-chunk on the GPU.
        # NOTE(review): dedup is only within a chunk; duplicates spanning two
        # chunks would survive (moot at the default 10**9 batch) — TODO confirm.
        res = []
        for i in range(math.ceil(len(cpu_tensor)/batch_size)):
            next = min(len(cpu_tensor), (i+1)*batch_size)
            res.append(torch.unique(cpu_tensor[i*batch_size:next].cuda(), dim=0).cpu())
        return torch.concat(res, dim=0)
    def combo(n, k):
        # Binomial coefficient via lgamma.  NOTE(review): unused below.
        return ((n + 1).lgamma() - (k + 1).lgamma() - ((n - k) + 1).lgamma()).exp()
    _E8_CODESZ = 8
    # Build all 8-bit patterns, keep those with even parity: the rows of
    # `bitmap` enumerate the even-sign-flip group used by flip_cb.
    int_map = 2**torch.arange(8)
    bitmap = torch.zeros(256, 8)
    for i in range(256):
        bitmap[i] = (i & int_map) != 0
    bitmap = bitmap[torch.where(bitmap.sum(dim=-1)%2 == 0)[0]].cuda()
    bits = 0
    cr = 2  # outer radius of the grid being enumerated
    data = [[], [], []]
    last_bits = 0
    while bits < bit_cap:
        # E8 = D8 (integer points, even sum) union D8 + 1/2 (half-integer
        # points, even sum).  The asymmetric base range [-1, cr) is presumably
        # compensated by the sign-flip expansion below — TODO confirm.
        base_grid = torch.arange(-1, cr).to(torch.float16)
        int_grid = torch.cartesian_prod(*[base_grid] * _E8_CODESZ)
        int_grid = int_grid[torch.where(int_grid.sum(dim=-1) % 2 == 0)[0]]
        hint_grid = torch.cartesian_prod(*[base_grid + 1/2] * _E8_CODESZ)
        hint_grid = hint_grid[torch.where(hint_grid.sum(dim=-1) % 2 == 0)[0]]
        grid = torch.concat([int_grid, hint_grid], dim=0)
        grid_norms = torch.sum(grid**2, dim=-1)
        norms = torch.unique(grid_norms)
        # Only norms in the new shell [(cr-1)^2, cr^2).
        norms = norms[torch.where((norms >= (cr - 1)**2) * (norms < cr**2))[0]]
        for norm in norms[::4]:  # subsample shells to keep the sweep cheap
            cb = grid[torch.where(grid_norms <= norm)[0]].cuda()
            # Full codebook: close under even sign flips, then deduplicate.
            cb = batched_unique(flip_cb(cb, bitmap))
            # Reduced codebook covering the canonicalized sample domain
            # (see round_mvn): coords 1..7 mostly nonnegative.
            idxs = torch.where(
                ((cb[:, 1:] < 0).sum(dim=-1) <= 1) * \
                (cb[:, 1:].min(dim=-1).values >= -0.5)
            )[0]
            cb_part = cb[idxs]
            # Bitrate counts the FULL codebook; rounding uses only cb_part
            # because the samples were canonicalized.
            bits = math.log(len(cb))/math.log(2)/_E8_CODESZ
            if bits - last_bits < 0.1:  # skip near-duplicate bitrates
                continue
            last_bits = bits
            data[0].append(bits)
            err, scale = round_mvn(cb_part.cuda())
            data[1].append(err)
            data[2].append(scale)
            print(norm.item(), bits, err, scale)
            if bits > bit_cap:
                return data
        cr += 1
    return data
def parse_cached(s):
    """Parse a cached sweep printout into (bits, err) lists.

    Expects whitespace-separated records of three fields each,
    "<label> <bits> <err>"; returns fields 1 and 2 of every record
    converted to float.
    """
    # split() with no argument splits on any run of whitespace (including
    # newlines), so the old replace('\n', ' ') pass and the redundant
    # .rstrip() after .strip() are unnecessary — and repeated separators no
    # longer inject empty tokens that would desynchronize the stride-3
    # indexing below.
    tokens = s.split()
    bits = [float(t) for t in tokens[1::3]]
    err = [float(t) for t in tokens[2::3]]
    return bits, err
# ---- Driver: sweep each codebook family up to 3.5 bits and cache the data ----
bit_cap = 3.5
hint_1c = get_hint_curve(bit_cap, 1)  # scalar half-integer grid
hint_4c = get_hint_curve(bit_cap, 4)  # 4-dim half-integer product grid
hint_8c = get_hint_curve(bit_cap, 8)  # 8-dim half-integer product grid
D4 = get_D4_curve(bit_cap)
E8 = get_E8_curve(bit_cap)
import pickle as pkl  # NOTE(review): mid-file import; harmless in a script
all_data = {
    'half_int_1col': hint_1c,
    'half_int_4col': hint_4c,
    'half_int_8col': hint_8c,
    'D4': D4,
    'E8': E8,
}
print(all_data)
pkl.dump(all_data, open('plot_data.pkl', 'wb'))
# NOTE(review): exit() makes all of the plotting code below unreachable —
# presumably the pickle is re-loaded by a separate plotting run. TODO confirm.
exit()
# ---- Plotting (currently unreachable: exit() above fires first) ----

def _plot_curve(curve, marker, label):
    """Scatter curve ([bits, errs, ...]) with `marker`, plus a faded line
    through the same points in the marker's auto-assigned color."""
    pts = plt.plot(curve[0], curve[1], marker, label=label)[0]
    # get_color() is the public Line2D accessor; the original reached into
    # the private `_color` attribute, which is not a stable matplotlib API.
    plt.plot(curve[0], curve[1], '-', alpha=0.5, color=pts.get_color())

plt.rcParams["figure.figsize"] = (6, 5)
plt.cla()
_plot_curve(hint_1c, 's', 'Half Integer 1 Column')
_plot_curve(hint_4c, 'o', 'Half Integer 4 Column')
_plot_curve(hint_8c, '+', 'Half Integer 8 Column')
_plot_curve(D4, '*', 'D4')
_plot_curve(E8, 'x', 'E8')
# Reference point measured separately for a padded 2^16-entry E8 codebook.
plt.plot(2.0, 0.0915, 'yD', label='E8 Padded ($2^{16}$ entries)')
plt.legend()
plt.title('Lowest MSE Achievable for a Multivariate Gaussian')
plt.ylabel('MSE')
plt.yscale('log')  # MSE spans orders of magnitude across bitrates
plt.xlabel('Bits')
plt.tight_layout()
plt.savefig('lattice_err.png', dpi=600)