| from BeamDiffusionModel.models.CoSeD.sequence_predictor import SoftAttention | |
| import torch | |
# Load the trained SoftAttention weights once at import time.
# map_location="cpu" pins the restored module to the CPU: get_softmax moves
# every input tensor to "cpu", so a module restored onto another device
# (e.g. the GPU the checkpoint was saved from) would fail with a device
# mismatch in the forward pass.
model = SoftAttention.load_from_checkpoint(
    "./BeamDiffusionModel/models/CoSeD/checkpoints/epoch=19-step=140.ckpt",
    map_location="cpu",
)
# Alternate checkpoint path (kept for reference, not used):
# "/user/home/vcc.ramos/latent_training/sft/reference_training/9q3eu8vi/checkpoints/epoch=7-step=15.ckpt"

# Switch the module to evaluation mode; it is only used for inference here.
model.eval()
def get_softmax(previous_steps_embeddings, previous_images_embeddings, current_steps_embeddings,
                current_images_embeddings):
    """Score the current step/image embeddings against the previous ones.

    Each argument is a sequence of tensors. Every sequence is concatenated
    along dim 0, moved to the CPU and given a leading dimension of size 1
    before being passed to the module-level ``model`` with gradient tracking
    disabled.

    Args:
        previous_steps_embeddings: Tensors for the previous steps.
        previous_images_embeddings: Tensors for the previous images.
        current_steps_embeddings: Tensors for the current steps.
        current_images_embeddings: Tensors for the current images.

    Returns:
        The first of the two values returned by ``model``. If it has at most
        one dimension it is returned unchanged; otherwise it is averaged over
        its last dimension (sum divided by that dimension's size).
    """

    def _as_cpu_batch(embeddings):
        # Join along dim 0, place on the CPU, then prepend a size-1 dimension
        # (presumably the batch axis the model expects -- TODO confirm).
        joined = torch.cat(embeddings, dim=0)
        return joined.to("cpu").unsqueeze(0)

    prev_steps = _as_cpu_batch(previous_steps_embeddings)
    prev_images = _as_cpu_batch(previous_images_embeddings)
    cur_steps = _as_cpu_batch(current_steps_embeddings)
    cur_images = _as_cpu_batch(current_images_embeddings)

    # Inference only: no autograd graph is needed. Note the model takes the
    # current embeddings first, then the previous ones.
    with torch.no_grad():
        probs, _ = model(cur_steps, cur_images, prev_steps, prev_images)

    if len(probs.shape) > 1:
        # Sum over the last dimension and normalise by its size.
        return torch.sum(probs, dim=-1) / probs.shape[-1]
    return probs