Upload transformer/model.py with huggingface_hub
Browse files- transformer/model.py +24 -6
transformer/model.py
CHANGED
|
@@ -508,7 +508,6 @@ class Upcast(nn.Module):
|
|
| 508 |
self.gca = GatedCrossAttention(
|
| 509 |
c_query=c_atom,
|
| 510 |
c_kv=c_token // n_split,
|
| 511 |
-
c_model=c_atom,
|
| 512 |
**(cross_attention_block or {}),
|
| 513 |
)
|
| 514 |
|
|
@@ -601,7 +600,6 @@ class Downcast(nn.Module):
|
|
| 601 |
self.gca = GatedCrossAttention(
|
| 602 |
c_query=c_token,
|
| 603 |
c_kv=c_atom,
|
| 604 |
-
c_model=c_token,
|
| 605 |
**(cross_attention_block or {}),
|
| 606 |
)
|
| 607 |
|
|
@@ -948,8 +946,21 @@ class RFD3DiffusionModule(nn.Module):
|
|
| 948 |
nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
|
| 949 |
])
|
| 950 |
|
| 951 |
-
|
| 952 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
self.process_a = LinearEmbedWithPool(c_token)
|
| 954 |
self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
|
| 955 |
|
|
@@ -995,8 +1006,15 @@ class RFD3DiffusionModule(nn.Module):
|
|
| 995 |
n_block=n_diffusion_blocks,
|
| 996 |
)
|
| 997 |
|
| 998 |
-
decoder_upcast = {
|
| 999 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1000 |
|
| 1001 |
self.decoder = CompactStreamingDecoder(
|
| 1002 |
c_atom=c_atom,
|
|
|
|
| 508 |
self.gca = GatedCrossAttention(
|
| 509 |
c_query=c_atom,
|
| 510 |
c_kv=c_token // n_split,
|
|
|
|
| 511 |
**(cross_attention_block or {}),
|
| 512 |
)
|
| 513 |
|
|
|
|
| 600 |
self.gca = GatedCrossAttention(
|
| 601 |
c_query=c_token,
|
| 602 |
c_kv=c_atom,
|
|
|
|
| 603 |
**(cross_attention_block or {}),
|
| 604 |
)
|
| 605 |
|
|
|
|
| 946 |
nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
|
| 947 |
])
|
| 948 |
|
| 949 |
+
cross_attention_block = {
|
| 950 |
+
"n_head": 4,
|
| 951 |
+
"c_model": c_atom,
|
| 952 |
+
"dropout": p_drop,
|
| 953 |
+
"kq_norm": True,
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
self.downcast_c = Downcast(
|
| 957 |
+
c_atom=c_atom, c_token=c_s, c_s=None,
|
| 958 |
+
method="cross_attention", cross_attention_block=cross_attention_block
|
| 959 |
+
)
|
| 960 |
+
self.downcast_q = Downcast(
|
| 961 |
+
c_atom=c_atom, c_token=c_token, c_s=c_s,
|
| 962 |
+
method="cross_attention", cross_attention_block=cross_attention_block
|
| 963 |
+
)
|
| 964 |
self.process_a = LinearEmbedWithPool(c_token)
|
| 965 |
self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
|
| 966 |
|
|
|
|
| 1006 |
n_block=n_diffusion_blocks,
|
| 1007 |
)
|
| 1008 |
|
| 1009 |
+
decoder_upcast = {
|
| 1010 |
+
"method": "cross_attention",
|
| 1011 |
+
"n_split": 3,
|
| 1012 |
+
"cross_attention_block": cross_attention_block,
|
| 1013 |
+
}
|
| 1014 |
+
decoder_downcast = {
|
| 1015 |
+
"method": "cross_attention",
|
| 1016 |
+
"cross_attention_block": cross_attention_block,
|
| 1017 |
+
}
|
| 1018 |
|
| 1019 |
self.decoder = CompactStreamingDecoder(
|
| 1020 |
c_atom=c_atom,
|