dn6 HF Staff committed on
Commit
8d0a59c
·
verified ·
1 Parent(s): 7e6b4cd

Upload transformer/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. transformer/model.py +24 -6
transformer/model.py CHANGED
@@ -508,7 +508,6 @@ class Upcast(nn.Module):
508
  self.gca = GatedCrossAttention(
509
  c_query=c_atom,
510
  c_kv=c_token // n_split,
511
- c_model=c_atom,
512
  **(cross_attention_block or {}),
513
  )
514
 
@@ -601,7 +600,6 @@ class Downcast(nn.Module):
601
  self.gca = GatedCrossAttention(
602
  c_query=c_token,
603
  c_kv=c_atom,
604
- c_model=c_token,
605
  **(cross_attention_block or {}),
606
  )
607
 
@@ -948,8 +946,21 @@ class RFD3DiffusionModule(nn.Module):
948
  nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
949
  ])
950
 
951
- self.downcast_c = Downcast(c_atom=c_atom, c_token=c_s, c_s=None, method="cross_attention")
952
- self.downcast_q = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, method="cross_attention")
 
 
 
 
 
 
 
 
 
 
 
 
 
953
  self.process_a = LinearEmbedWithPool(c_token)
954
  self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
955
 
@@ -995,8 +1006,15 @@ class RFD3DiffusionModule(nn.Module):
995
  n_block=n_diffusion_blocks,
996
  )
997
 
998
- decoder_upcast = {"method": "cross_attention"}
999
- decoder_downcast = {"method": "cross_attention"}
 
 
 
 
 
 
 
1000
 
1001
  self.decoder = CompactStreamingDecoder(
1002
  c_atom=c_atom,
 
508
  self.gca = GatedCrossAttention(
509
  c_query=c_atom,
510
  c_kv=c_token // n_split,
 
511
  **(cross_attention_block or {}),
512
  )
513
 
 
600
  self.gca = GatedCrossAttention(
601
  c_query=c_token,
602
  c_kv=c_atom,
 
603
  **(cross_attention_block or {}),
604
  )
605
 
 
946
  nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
947
  ])
948
 
949
+ cross_attention_block = {
950
+ "n_head": 4,
951
+ "c_model": c_atom,
952
+ "dropout": p_drop,
953
+ "kq_norm": True,
954
+ }
955
+
956
+ self.downcast_c = Downcast(
957
+ c_atom=c_atom, c_token=c_s, c_s=None,
958
+ method="cross_attention", cross_attention_block=cross_attention_block
959
+ )
960
+ self.downcast_q = Downcast(
961
+ c_atom=c_atom, c_token=c_token, c_s=c_s,
962
+ method="cross_attention", cross_attention_block=cross_attention_block
963
+ )
964
  self.process_a = LinearEmbedWithPool(c_token)
965
  self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
966
 
 
1006
  n_block=n_diffusion_blocks,
1007
  )
1008
 
1009
+ decoder_upcast = {
1010
+ "method": "cross_attention",
1011
+ "n_split": 3,
1012
+ "cross_attention_block": cross_attention_block,
1013
+ }
1014
+ decoder_downcast = {
1015
+ "method": "cross_attention",
1016
+ "cross_attention_block": cross_attention_block,
1017
+ }
1018
 
1019
  self.decoder = CompactStreamingDecoder(
1020
  c_atom=c_atom,