Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from torch import Tensor | |
| import matplotlib.pyplot as plt | |
| from s_multimae.model.multimae import build_2d_sincos_posemb | |
def visualize_2d_posemb():
    """Plot selected channels of the 2D sin-cos positional embedding as 3D surfaces.

    Builds the positional embedding for a 14 x 14 token grid with 768
    embedding dimensions via ``build_2d_sincos_posemb``, then, for every
    100th channel, draws that channel's values over the token grid as a 3D
    surface, writes the figure to ``posemb_visualization/test_<channel>.png``
    and displays it.

    Side effects:
        Creates the ``posemb_visualization`` directory if it is missing,
        writes one PNG per plotted channel, and opens a plot window on
        interactive matplotlib backends.
    """
    # Function-scope import so this block does not depend on a file-level
    # ``import os`` being present.
    import os

    grid_h, grid_w = 14, 14
    dim_tokens = 768
    channel_step = 100
    output_dir = "posemb_visualization"

    # savefig does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    pos_emb: Tensor = build_2d_sincos_posemb(grid_h, grid_w, dim_tokens)
    # NOTE(review): assumes the helper returns a (1, dim_tokens, H, W) tensor,
    # so the result below is (H, W, dim_tokens), i.e. 14 x 14 x 768 -- confirm
    # against build_2d_sincos_posemb.
    pos_emb_numpy: np.ndarray = pos_emb.squeeze(0).permute(1, 2, 0).numpy()

    # With meshgrid's default "xy" indexing the outputs have shape
    # (len(second arg), len(first arg)); passing the width axis first makes
    # X and Y match the (H, W) layout of each channel slice even when the
    # grid is not square.
    X, Y = np.meshgrid(
        np.arange(grid_w, dtype=float),
        np.arange(grid_h, dtype=float),
    )

    for i in range(0, pos_emb_numpy.shape[2], channel_step):
        # Fresh figure per channel so surfaces from earlier iterations are
        # not drawn into the same axes.
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
        ax.plot_surface(
            X,
            Y,
            pos_emb_numpy[:, :, i],
            cmap="viridis",
            edgecolor="none",
        )
        # Save before show(): on interactive backends the figure is
        # destroyed once the window is closed, so saving afterwards would
        # write an empty image.
        fig.savefig(f"{output_dir}/test_{i}.png")
        plt.show()
        # Release the figure explicitly; pyplot keeps figures alive otherwise.
        plt.close(fig)