# modeling_smoothie.py — Smoothie diffusion model (repo: smoothie-diffusion-qqp, upload a88c182)
import torch
import torch.nn as nn
import math
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
class SmoothieConfig(PretrainedConfig):
    """HuggingFace configuration for the Smoothie diffusion denoiser.

    Holds the transformer hyperparameters (width, depth, heads), the maximum
    sequence length for positional embeddings, the vocabulary size of the
    optional conditioning encoder, and whether conditioning is enabled.
    """

    model_type = "smoothie"

    def __init__(
        self,
        d_model=768,
        num_layers=12,
        num_heads=12,
        max_seq_len=50,
        vocab_size=28996,
        conditional=True,
        **kwargs,
    ):
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.conditional = conditional
        # Attributes are assigned first, then the base-class init consumes the
        # remaining HF-standard kwargs (same order as the original).
        super().__init__(**kwargs)
class TimeEmbed(nn.Module):
    """Sinusoidal diffusion-timestep embedding refined by a small MLP.

    Maps a batch of scalar timesteps to ``(batch, dim)`` vectors: half the
    channels carry sines, half carry cosines over geometrically spaced
    frequencies, then an MLP (dim -> 4*dim -> Mish -> dim) mixes them.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        """Return MLP-refined sinusoidal embeddings for timesteps ``t``."""
        half = self.dim // 2
        # Frequency ladder 10000^(-k/(half-1)), the standard transformer scale.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -step)
        angles = t[:, None] * freqs[None, :]
        sinusoid = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.mlp(sinusoid)
class SmoothieModel(PreTrainedModel):
    # Transformer denoiser for embedding-space diffusion. Given a noisy
    # embedding sequence and a timestep, it predicts a denoised sequence of the
    # same width; optionally conditioned on encoded source tokens (cross-attn).
    config_class = SmoothieConfig
    def __init__(self, config: SmoothieConfig):
        # Timestep embedding, input projection, and a learned positional
        # embedding over at most config.max_seq_len positions.
        super().__init__(config); self.config=config; self.time_embed=TimeEmbed(config.d_model); self.input_proj=nn.Linear(config.d_model, config.d_model); self.pos_embed=nn.Embedding(config.max_seq_len, config.d_model)
        # Denoising trunk: num_layers decoder layers (batch_first tensors).
        self.decoder_layers=nn.ModuleList([nn.TransformerDecoderLayer(d_model=config.d_model, nhead=config.num_heads, batch_first=True) for _ in range(config.num_layers)])
        # Conditioning branch: token embedding + a fixed 6-layer encoder
        # (note: encoder depth is hard-coded, independent of config.num_layers).
        if config.conditional: self.encoder_embedding=nn.Embedding(config.vocab_size, config.d_model); encoder_layer=nn.TransformerEncoderLayer(d_model=config.d_model, nhead=config.num_heads, batch_first=True); self.encoder=nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.output_proj=nn.Linear(config.d_model, config.d_model)
    def forward(self, weighted_avg_emb, t, src_tokens=None, src_mask=None, **kwargs):
        """Run one denoising step.

        Args:
            weighted_avg_emb: noisy embeddings; indexed as (batch, seq, d_model)
                given the .size(1) positional range below — TODO confirm.
            t: per-example diffusion timestep, shape (batch,).
            src_tokens: optional conditioning token ids for the encoder.
            src_mask: passed as a key-padding mask to encoder and decoder
                cross-attention; presumably True marks padding — verify caller.

        Returns:
            Tensor from output_proj, same trailing width d_model as the input.
        """
        # Project the input and add broadcast time + positional embeddings.
        device=weighted_avg_emb.device; x=self.input_proj(weighted_avg_emb); time_emb=self.time_embed(t).unsqueeze(1); pos=torch.arange(0, weighted_avg_emb.size(1), device=device).unsqueeze(0); pos_emb=self.pos_embed(pos); x=x+time_emb+pos_emb
        encoder_output=None
        # Encode the conditioning sequence when both the flag and tokens exist.
        if self.config.conditional and src_tokens is not None: src_emb=self.encoder_embedding(src_tokens); encoder_output=self.encoder(src_emb, src_key_padding_mask=src_mask)
        # U-Net-style pass: first half pushes skip states, middle layer runs
        # alone, second half adds popped skips (LIFO) before each layer.
        # NOTE(review): with even num_layers the first loop pushes
        # num_layers//2 skips but the final loop runs only num_layers//2 - 1
        # times, so one skip tensor is pushed and never consumed. This is the
        # architecture the checkpoint was trained with, so it is documented
        # rather than changed — confirm against the training code.
        skip_connections=[]; num_layers=len(self.decoder_layers)
        for i in range(num_layers // 2): skip_connections.append(x); x=self.decoder_layers[i](tgt=x, memory=encoder_output, memory_key_padding_mask=src_mask)
        x=self.decoder_layers[num_layers // 2](tgt=x, memory=encoder_output, memory_key_padding_mask=src_mask)
        for i in range(num_layers // 2 + 1, num_layers): x=x+skip_connections.pop(); x=self.decoder_layers[i](tgt=x, memory=encoder_output, memory_key_padding_mask=src_mask)
        return self.output_proj(x)
# Register with the HF Auto* factories so AutoConfig/AutoModel.from_pretrained
# can resolve model_type "smoothie" to these classes.
AutoConfig.register("smoothie", SmoothieConfig)
AutoModel.register(SmoothieConfig, SmoothieModel)