|
|
from lightning import Callback |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
import os |
|
|
import numpy as np |
|
|
import torchvision |
|
|
from einops import rearrange |
|
|
|
|
|
class VisualizationCallback(Callback):
    """Lightning callback that periodically saves a visualization during training.

    Every ``save_freq`` global steps (``trainer.global_step``), the global-rank-zero
    process calls :meth:`save_visualization` with the current training batch.
    Subclasses override :meth:`save_visualization` to render something
    model-specific; the base implementation writes a placeholder line plot.

    Args:
        save_freq: Number of global steps between visualizations (must be > 0).
        output_dir: Directory the PNG files are written to; created if missing.

    Raises:
        ValueError: If ``save_freq`` is not a positive number.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__()
        if save_freq <= 0:
            # A zero/negative frequency would otherwise surface later as a
            # ZeroDivisionError (or never fire) inside the training loop.
            raise ValueError(f"save_freq must be positive, got {save_freq}")
        self.save_freq = save_freq
        self.output_dir = output_dir
        # Last global step that was rendered. With gradient accumulation,
        # ``global_step`` stays constant across several consecutive batches,
        # so without this guard the same step would be rendered repeatedly.
        self._last_saved_step = None
        # exist_ok avoids the check-then-create race (e.g. several processes).
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        """Render a visualization on rank zero once per ``save_freq`` global steps."""
        if not trainer.is_global_zero:
            return

        global_step = trainer.global_step
        if global_step % self.save_freq != 0:
            return
        # getattr keeps subclasses that skip super().__init__() working.
        if global_step == getattr(self, "_last_saved_step", None):
            return

        self._last_saved_step = global_step
        self.save_visualization(trainer, model, global_step, batch)

    def save_visualization(self, trainer, model, global_step, batch):
        """Save a placeholder plot for ``global_step``; subclasses override this."""
        fig, ax = plt.subplots()
        ax.plot([1, 2, 3], [4, 5, 6])
        ax.set_title(f"Visualization at Step {global_step}")
        fig.savefig(os.path.join(self.output_dir, f"visualization_{global_step}.png"))
        plt.close(fig)
        print(f"Saved visualization at step {global_step}")
|
|
|
|
|
|
|
|
class VisualizationVAECallback(VisualizationCallback):
    """Save side-by-side grids of ground-truth and reconstructed images.

    ``model(batch)`` is unpacked as ``(x_pred, x_gt)``. Both are presumably NCHW
    image tensors with values nominally in [0, 1] (they are clamped to that range
    for display) — TODO confirm against the model's forward.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Render ground truth (left) and reconstruction (right) for ``batch``.

        The forward pass runs in eval mode without gradients. The model's
        previous train/eval mode is restored afterwards so training is unaffected.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x_pred, x_gt = model(batch)
        finally:
            # Leaving the model in eval mode would silently disable dropout and
            # freeze batch-norm statistics for the rest of training.
            model.train(was_training)

        x_pred = torch.clamp(x_pred.detach().cpu(), min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt.detach().cpu(), min=0.0, max=1.0)

        # make_grid's ``nrow`` is the number of images per row; ceil(sqrt(B))
        # yields a roughly square grid.
        per_row = int(np.ceil(np.sqrt(x_gt.shape[0])))
        gt_grid = torchvision.utils.make_grid(x_gt, nrow=per_row)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=per_row)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        for ax, grid in zip(axes, (gt_grid, pred_grid)):
            ax.imshow(grid.permute(1, 2, 0).numpy())
            ax.axis('off')
        fig.tight_layout()

        # No plt.show() here: it blocks training under interactive backends and
        # may tear down the figure before it is saved.
        fig.savefig(os.path.join(self.output_dir, f"image_grid_{global_step}.png"))
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Visualization_HeadAnimator_Callback(VisualizationCallback):
    """Save reference / ground-truth / reconstruction / error grids for the head animator.

    Reads the batch keys ``'pixel_values_vid'``, ``'pixel_values_ref_img'``,
    ``'ref_img_original'`` and ``'pixel_values_vid_original'``. Video tensors are
    rearranged from ``b t c h w`` to ``(b t) c h w``. Each reference image
    (presumably ``b c h w`` — TODO confirm) is repeated once per frame so every
    frame is paired with its reference.
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Run the model on ``batch`` and save a 4-panel comparison image.

        Panels, left to right: reference, ground truth, prediction, and
        per-pixel absolute error (channel-averaged, shown with the 'jet' colormap).
        """
        masked_target_vid = batch['pixel_values_vid']
        masked_ref_img = batch['pixel_values_ref_img']
        ref_img_original = batch['ref_img_original']
        target_vid_original = batch['pixel_values_vid_original']

        # Add a time axis to the reference images and repeat it for every
        # frame, then fold time into the batch axis for the model.
        masked_ref_img = masked_ref_img[:, None].repeat(1, masked_target_vid.size(1), 1, 1, 1)
        masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w")
        masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w")

        ref_img_original = ref_img_original[:, None].repeat(1, target_vid_original.size(1), 1, 1, 1)
        ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w")
        target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w")

        # Evaluate in eval mode so this extra forward pass does not update
        # batch-norm statistics or sample dropout; restore the previous mode after.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model_out = model(ref_img_original, target_vid_original, masked_ref_img, masked_target_vid)
        finally:
            model.train(was_training)

        x_pred = model_out['recon_img'].detach().cpu()
        x_gt = target_vid_original.cpu()
        x_ref = ref_img_original.cpu()

        # Heuristic: a minimum below -0.5 suggests the data lives in [-1, 1];
        # map all tensors to [0, 1] for display.
        if x_gt.min() < -0.5:
            x_gt = (x_gt + 1) / 2
            x_pred = (x_pred + 1) / 2
            x_ref = (x_ref + 1) / 2

        x_pred = torch.clamp(x_pred, min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt, min=0.0, max=1.0)
        x_ref = torch.clamp(x_ref, min=0.0, max=1.0)

        # make_grid's ``nrow`` is images per row; ceil(sqrt(N)) gives a roughly
        # square grid (N is batch * frames here).
        per_row = int(np.ceil(np.sqrt(x_gt.shape[0])))
        ref_grid = torchvision.utils.make_grid(x_ref, nrow=per_row)
        gt_grid = torchvision.utils.make_grid(x_gt, nrow=per_row)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=per_row)

        diff = (x_pred - x_gt).abs()
        # Average over channels to get a 2-D error map: imshow ignores ``cmap``
        # for RGB input, so the 'jet' heatmap only applies to 2-D data.
        diff_map = torchvision.utils.make_grid(diff, nrow=per_row).mean(dim=0)

        fig, axes = plt.subplots(1, 4, figsize=(12, 6))
        for ax, grid in zip(axes[:3], (ref_grid, gt_grid, pred_grid)):
            ax.imshow(grid.permute(1, 2, 0).numpy())
            ax.axis('off')
        axes[3].imshow(diff_map.numpy(), cmap='jet')
        axes[3].axis('off')
        fig.tight_layout()

        # No plt.show() here: it blocks training under interactive backends and
        # may tear down the figure before it is saved.
        fig.savefig(os.path.join(self.output_dir, f"image_grid_{global_step}.png"))
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|