"""Load the trained CoSeD SoftAttention checkpoint and expose a helper that
scores candidate (step, image) embeddings against previously selected ones."""

from BeamDiffusionModel.models.CoSeD.sequence_predictor import SoftAttention

import torch

# Checkpoint from the CoSeD sequence-predictor training run.
# (An earlier run lived at
# /user/home/vcc.ramos/latent_training/sft/reference_training/9q3eu8vi/checkpoints/epoch=7-step=15.ckpt)
_CHECKPOINT_PATH = "./BeamDiffusionModel/models/CoSeD/checkpoints/epoch=19-step=140.ckpt"

# All inference in this module runs on CPU.
_DEVICE = "cpu"

model = SoftAttention.load_from_checkpoint(_CHECKPOINT_PATH)
model.eval()


def get_softmax(previous_steps_embeddings, previous_images_embeddings,
                current_steps_embeddings, current_images_embeddings):
    """Run the SoftAttention model and return one score per candidate.

    Args:
        previous_steps_embeddings: list of tensors for already-chosen steps.
        previous_images_embeddings: list of tensors for already-chosen images.
        current_steps_embeddings: list of tensors for the candidate steps.
        current_images_embeddings: list of tensors for the candidate images.

    NOTE(review): assumes each list's tensors concatenate along dim 0 into a
    (seq, dim) tensor before the batch dim is added — confirm against callers.

    Returns:
        The model's softmax output unchanged when it is 0-D/1-D; otherwise the
        softmax averaged over its last dimension (sum / size).
    """
    def _batch(embeddings):
        # Concatenate the per-item tensors, move to CPU, add a batch dim.
        # The original wrote this four times, inconsistently (dim=0 sometimes
        # implicit); torch.cat's default dim is 0, so behavior is unchanged.
        return torch.cat(embeddings, dim=0).to(_DEVICE).unsqueeze(0)

    previous_steps = _batch(previous_steps_embeddings)
    previous_images = _batch(previous_images_embeddings)
    current_steps = _batch(current_steps_embeddings)
    current_images = _batch(current_images_embeddings)

    with torch.no_grad():  # inference only; no gradients needed
        softmax, _logit = model(current_steps, current_images,
                                previous_steps, previous_images)

    # A 0-D or 1-D output is already one score per candidate.
    if softmax.dim() <= 1:
        return softmax

    # Collapse the last dimension to a mean (sum then normalize by size).
    return torch.sum(softmax, dim=-1) / softmax.shape[-1]