# fix error
# commit 9ad5b1d (author: robinwitch)
from lightning import Callback
import torch
import matplotlib.pyplot as plt
import os
import numpy as np
import torchvision
from einops import rearrange
class VisualizationCallback(Callback):
    """Periodically writes a visualization PNG during training.

    Subclasses override :meth:`save_visualization` to render their own
    figures; this base class draws a placeholder line plot.

    Args:
        save_freq: save a figure every ``save_freq`` global steps.
        output_dir: directory the PNG files are written to (created if
            missing).
    """

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__()
        self.save_freq = save_freq
        self.output_dir = output_dir
        # exist_ok avoids the TOCTOU race between an exists() check and
        # makedirs() when several processes start at once.
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        """Trigger :meth:`save_visualization` every ``save_freq`` steps."""
        # Only the rank-0 process writes files, so multi-GPU runs do not
        # produce duplicate images.
        if trainer.is_global_zero:
            global_step = trainer.global_step
            if global_step % self.save_freq == 0:
                self.save_visualization(trainer, model, global_step, batch)

    def save_visualization(self, trainer, model, global_step, batch):
        """Save a placeholder plot; subclasses replace this with real output."""
        fig, ax = plt.subplots()
        ax.plot([1, 2, 3], [4, 5, 6])  # dummy data — overridden in subclasses
        ax.set_title(f"Visualization at Step {global_step}")
        plt.savefig(f"{self.output_dir}/visualization_{global_step}.png")
        plt.close(fig)
        print(f"Saved visualization at step {global_step}")
class VisualizationVAECallback(VisualizationCallback):
    """Saves side-by-side grids of ground-truth and reconstructed images."""

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Run the model on ``batch`` and save a GT-vs-reconstruction figure.

        Assumes ``model(batch)`` returns ``(prediction, ground_truth)``
        image tensors of shape (B, C, H, W) — TODO confirm against the model.
        """
        # Remember the current mode: the original code called model.eval()
        # and never switched back, leaving Dropout/BatchNorm disabled for the
        # rest of training. Restore the mode even if the forward pass raises.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x_pred, x_gt = model(batch)
        finally:
            if was_training:
                model.train()
        x_pred = torch.clamp(x_pred.cpu(), min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt.cpu(), min=0.0, max=1.0)
        # Roughly square grid layout.
        rows = int(np.ceil(np.sqrt(x_gt.shape[0])))
        gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows)
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(gt_grid.permute(1, 2, 0))
        axes[0].axis('off')
        axes[1].imshow(pred_grid.permute(1, 2, 0))
        axes[1].axis('off')
        plt.tight_layout()
        # Save before closing; plt.show() is dropped because it blocks (or is
        # a no-op on the Agg backend) inside a headless training run.
        plt.savefig(f"{self.output_dir}/image_grid_{global_step}.png")
        plt.close(fig)
class Visualization_HeadAnimator_Callback(VisualizationCallback):
    """Saves reference / ground-truth / prediction / error grids for the
    head-animator model."""

    def __init__(self, save_freq=2000, output_dir="visualizations"):
        super().__init__(save_freq, output_dir)

    def save_visualization(self, trainer, model, global_step, batch):
        """Reconstruct a video batch and save a 4-panel comparison figure.

        Expects ``batch`` to carry (B, T, C, H, W) videos under
        'pixel_values_vid' / 'pixel_values_vid_original' and (B, C, H, W)
        reference images under 'pixel_values_ref_img' / 'ref_img_original'
        — TODO confirm against the datamodule.
        """
        masked_target_vid = batch['pixel_values_vid']
        masked_ref_img = batch['pixel_values_ref_img']
        ref_img_original = batch['ref_img_original']
        target_vid_original = batch['pixel_values_vid_original']
        # Pair every target frame with its (repeated) reference image, then
        # fold the time axis into the batch axis.
        masked_ref_img = masked_ref_img[:, None].repeat(1, masked_target_vid.size(1), 1, 1, 1)
        masked_ref_img = rearrange(masked_ref_img, "b t c h w -> (b t) c h w")
        masked_target_vid = rearrange(masked_target_vid, "b t c h w -> (b t) c h w")
        ref_img_original = ref_img_original[:, None].repeat(1, target_vid_original.size(1), 1, 1, 1)
        ref_img_original = rearrange(ref_img_original, "b t c h w -> (b t) c h w")
        target_vid_original = rearrange(target_vid_original, "b t c h w -> (b t) c h w")

        # Switch to eval for the probe forward: torch.no_grad() alone does
        # not stop BatchNorm running-stat updates. Restore the mode after.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model_out = model.forward(ref_img_original, target_vid_original,
                                          masked_ref_img, masked_target_vid)
        finally:
            if was_training:
                model.train()

        x_pred = model_out['recon_img'].cpu()
        x_gt = target_vid_original.cpu()
        x_ref = ref_img_original.cpu()
        # Heuristic: inputs normalized to [-1, 1] are mapped back to [0, 1].
        if x_gt.min() < -0.5:
            x_gt = (x_gt + 1) / 2
            x_pred = (x_pred + 1) / 2
            x_ref = (x_ref + 1) / 2
        x_pred = torch.clamp(x_pred, min=0.0, max=1.0)
        x_gt = torch.clamp(x_gt, min=0.0, max=1.0)
        x_ref = torch.clamp(x_ref, min=0.0, max=1.0)

        # Roughly square grid layout.
        rows = int(np.ceil(np.sqrt(x_gt.shape[0])))
        ref_grid = torchvision.utils.make_grid(x_ref, nrow=rows)
        gt_grid = torchvision.utils.make_grid(x_gt, nrow=rows)
        pred_grid = torchvision.utils.make_grid(x_pred, nrow=rows)
        diff_grid = torchvision.utils.make_grid((x_pred - x_gt).abs(), nrow=rows)

        fig, axes = plt.subplots(1, 4, figsize=(12, 6))
        axes[0].imshow(ref_grid.permute(1, 2, 0))
        axes[1].imshow(gt_grid.permute(1, 2, 0))
        axes[2].imshow(pred_grid.permute(1, 2, 0))
        # imshow ignores `cmap` for 3-channel RGB input, so collapse the
        # error map to a single channel for the jet colormap to take effect.
        axes[3].imshow(diff_grid.permute(1, 2, 0).mean(dim=-1), cmap='jet')
        for ax in axes:
            ax.axis('off')
        plt.tight_layout()
        # Save before closing; no plt.show() in a headless training loop.
        plt.savefig(f"{self.output_dir}/image_grid_{global_step}.png")
        plt.close(fig)