new-language_model / utils /encoder_utils.py
Raja-65's picture
Add ELF demo
6ab8280
Raw
History Blame Contribute Delete
989 Bytes
import torch
import numpy as np
@torch.no_grad()
def encode_text(
input_ids,
attention_mask,
encoder,
latent_mean,
latent_std,
use_bf16=True,
):
"""Encoder pass from text to latent with normalization."""
autocast_enabled = bool(use_bf16) and input_ids.is_cuda
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=autocast_enabled):
latents = encoder(input_ids=input_ids, attention_mask=attention_mask, deterministic=True)
return (latents - latent_mean) / latent_std
def build_self_attn_cond_masks(is_cond, is_valid, xp=np):
    """Derive float32 attention masks from per-token cond/valid flags.

    The pairwise mask entry ``[b, i, j]`` is 1.0 when either
      * positions ``i`` and ``j`` are both flagged in ``is_cond``, or
      * position ``i`` is not flagged in ``is_cond`` and position ``j`` is
        flagged in ``is_valid``;
    otherwise it is 0.0.

    Args:
        is_cond: Flag array indexed as ``[:, :, None]`` / ``[:, None, :]``
            (assumed boolean with shape (batch, seq) -- TODO confirm with
            callers).
        is_valid: Flag array broadcast against ``is_cond`` (assumed to share
            its shape and dtype -- TODO confirm).
        xp: Array namespace supplying ``float32`` (NumPy by default).

    Returns:
        A 3-tuple ``(encoder_attention_mask, attention_mask, cond_seq_mask)``:
        the pairwise mask described above, ``is_valid`` cast to float32, and
        ``is_cond`` cast to float32.
    """
    f32 = xp.float32

    # Broadcastable views: one varying along axis 1, the others along axis 2.
    row_cond = is_cond[:, :, None]
    col_cond = is_cond[:, None, :]
    col_valid = is_valid[:, None, :]

    cond_pairs = row_cond & col_cond
    non_cond_to_valid = ~row_cond & col_valid
    pairwise_mask = (cond_pairs | non_cond_to_valid).astype(f32)

    valid_mask = is_valid.astype(f32)
    cond_mask = is_cond.astype(f32)
    return pairwise_mask, valid_mask, cond_mask