Update graph_decoder/transformer.py
Browse files
graph_decoder/transformer.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from .layers import Attention, MLP
|
| 4 |
from .conditions import TimestepEmbedder, ConditionEmbedder
|
| 5 |
-
from .diffusion_utils import PlaceHolder
|
| 6 |
|
| 7 |
def modulate(x, shift, scale):
|
| 8 |
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
@@ -98,7 +98,8 @@ class Transformer(nn.Module):
|
|
| 98 |
|
| 99 |
# X: B * N * dx, E: B * N * N * de
|
| 100 |
X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
|
| 101 |
-
return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
|
|
|
|
| 102 |
|
| 103 |
class Block(nn.Module):
|
| 104 |
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
from .layers import Attention, MLP
|
| 4 |
from .conditions import TimestepEmbedder, ConditionEmbedder
|
| 5 |
+
# from .diffusion_utils import PlaceHolder
|
| 6 |
|
| 7 |
def modulate(x, shift, scale):
|
| 8 |
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
|
|
| 98 |
|
| 99 |
# X: B * N * dx, E: B * N * N * de
|
| 100 |
X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
|
| 101 |
+
# return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
|
| 102 |
+
return X, E
|
| 103 |
|
| 104 |
class Block(nn.Module):
|
| 105 |
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|