Create models/__init__.py
Browse files- models/__init__.py +55 -0
models/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
(c) Adaptation of the code from https://github.com/THUDM/GraphMAE
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .edcoder import PreModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_model(args):
|
| 9 |
+
num_heads = args.num_heads
|
| 10 |
+
num_out_heads = args.num_out_heads
|
| 11 |
+
num_hidden = args.num_hidden
|
| 12 |
+
num_layers = args.num_layers
|
| 13 |
+
residual = args.residual
|
| 14 |
+
attn_drop = args.attn_drop
|
| 15 |
+
in_drop = args.in_drop
|
| 16 |
+
norm = args.norm
|
| 17 |
+
negative_slope = args.negative_slope
|
| 18 |
+
encoder_type = args.encoder
|
| 19 |
+
decoder_type = args.decoder
|
| 20 |
+
mask_rate = args.mask_rate
|
| 21 |
+
drop_edge_rate = args.drop_edge_rate
|
| 22 |
+
replace_rate = args.replace_rate
|
| 23 |
+
batchnorm = args.batchnorm
|
| 24 |
+
|
| 25 |
+
activation = args.activation
|
| 26 |
+
loss_fn = args.loss_fn
|
| 27 |
+
alpha_l = args.alpha_l
|
| 28 |
+
concat_hidden = args.concat_hidden
|
| 29 |
+
num_features = args.num_features
|
| 30 |
+
num_edge_features = args.num_edge_features
|
| 31 |
+
|
| 32 |
+
model = PreModel(
|
| 33 |
+
in_dim=int(num_features),
|
| 34 |
+
edge_in_dim=int(num_edge_features),
|
| 35 |
+
num_hidden=int(num_hidden),
|
| 36 |
+
num_layers=num_layers,
|
| 37 |
+
nhead=num_heads,
|
| 38 |
+
nhead_out=num_out_heads,
|
| 39 |
+
activation=activation,
|
| 40 |
+
feat_drop=in_drop,
|
| 41 |
+
attn_drop=attn_drop,
|
| 42 |
+
negative_slope=negative_slope,
|
| 43 |
+
residual=residual,
|
| 44 |
+
encoder_type=encoder_type,
|
| 45 |
+
decoder_type=decoder_type,
|
| 46 |
+
mask_rate=mask_rate,
|
| 47 |
+
norm=norm,
|
| 48 |
+
loss_fn=loss_fn,
|
| 49 |
+
drop_edge_rate=drop_edge_rate,
|
| 50 |
+
replace_rate=replace_rate,
|
| 51 |
+
alpha_l=alpha_l,
|
| 52 |
+
concat_hidden=concat_hidden,
|
| 53 |
+
batchnorm=batchnorm,
|
| 54 |
+
)
|
| 55 |
+
return model
|