Maxlegrec commited on
Commit
8a58a50
·
verified ·
1 Parent(s): cf44e3f

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +39 -24
model.py CHANGED
@@ -10,7 +10,7 @@ import bulletchess
10
  import numpy as np
11
 
12
 
13
- from transformers import PretrainedConfig
14
 
15
  class Gating(nn.Module):
16
  def __init__(self, features_shape, additive=True, init_value=None):
@@ -280,31 +280,46 @@ class BT4Config(PretrainedConfig):
280
  self.smol_gen_sz = smol_gen_sz
281
  self.smol_activation = smol_activation
282
 
283
- class BT4(nn.Module):
284
  def __init__(self, config=None, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=Mish(),
285
  use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
286
- super(BT4, self).__init__()
287
-
288
- # Store config if provided
289
- self.config = config
290
-
291
- # If config is provided, use it to override parameters
292
- if config is not None:
293
- embedding_size = getattr(config, 'embedding_size', embedding_size)
294
- embedding_dense_sz = getattr(config, 'embedding_dense_sz', embedding_dense_sz)
295
- encoder_layers = getattr(config, 'encoder_layers', encoder_layers)
296
- encoder_d_model = getattr(config, 'encoder_d_model', encoder_d_model)
297
- encoder_heads = getattr(config, 'encoder_heads', encoder_heads)
298
- encoder_dff = getattr(config, 'encoder_dff', encoder_dff)
299
- dropout_rate = getattr(config, 'dropout_rate', dropout_rate)
300
- pol_embedding_size = getattr(config, 'pol_embedding_size', pol_embedding_size)
301
- policy_d_model = getattr(config, 'policy_d_model', policy_d_model)
302
- val_embedding_size = getattr(config, 'val_embedding_size', val_embedding_size)
303
- use_smolgen = getattr(config, 'use_smolgen', use_smolgen)
304
- smol_hidden_channels = getattr(config, 'smol_hidden_channels', smol_hidden_channels)
305
- smol_hidden_sz = getattr(config, 'smol_hidden_sz', smol_hidden_sz)
306
- smol_gen_sz = getattr(config, 'smol_gen_sz', smol_gen_sz)
307
- smol_activation = getattr(config, 'smol_activation', smol_activation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  self.embedding_dense_sz = embedding_dense_sz
309
  # DeepNorm alpha used in embedding residual; default uses provided encoder_layers
310
  self.deepnorm_alpha = (2. * encoder_layers) ** -0.25
 
10
  import numpy as np
11
 
12
 
13
+ from transformers import PretrainedConfig, PreTrainedModel
14
 
15
  class Gating(nn.Module):
16
  def __init__(self, features_shape, additive=True, init_value=None):
 
280
  self.smol_gen_sz = smol_gen_sz
281
  self.smol_activation = smol_activation
282
 
283
+ class BT4(PreTrainedModel):
284
  def __init__(self, config=None, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=Mish(),
285
  use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
286
+ # Initialize PreTrainedModel with config
287
+ if config is None:
288
+ config = BT4Config(
289
+ embedding_size=embedding_size,
290
+ embedding_dense_sz=embedding_dense_sz,
291
+ encoder_layers=encoder_layers,
292
+ encoder_d_model=encoder_d_model,
293
+ encoder_heads=encoder_heads,
294
+ encoder_dff=encoder_dff,
295
+ dropout_rate=dropout_rate,
296
+ pol_embedding_size=pol_embedding_size,
297
+ policy_d_model=policy_d_model,
298
+ val_embedding_size=val_embedding_size,
299
+ use_smolgen=use_smolgen,
300
+ smol_hidden_channels=smol_hidden_channels,
301
+ smol_hidden_sz=smol_hidden_sz,
302
+ smol_gen_sz=smol_gen_sz,
303
+ smol_activation=smol_activation,
304
+ )
305
+ super(BT4, self).__init__(config)
306
+
307
+ # Use config values (config is now guaranteed to exist)
308
+ embedding_size = config.embedding_size
309
+ embedding_dense_sz = config.embedding_dense_sz
310
+ encoder_layers = config.encoder_layers
311
+ encoder_d_model = config.encoder_d_model
312
+ encoder_heads = config.encoder_heads
313
+ encoder_dff = config.encoder_dff
314
+ dropout_rate = config.dropout_rate
315
+ pol_embedding_size = config.pol_embedding_size
316
+ policy_d_model = config.policy_d_model
317
+ val_embedding_size = config.val_embedding_size
318
+ use_smolgen = config.use_smolgen
319
+ smol_hidden_channels = config.smol_hidden_channels
320
+ smol_hidden_sz = config.smol_hidden_sz
321
+ smol_gen_sz = config.smol_gen_sz
322
+ smol_activation = config.smol_activation
323
  self.embedding_dense_sz = embedding_dense_sz
324
  # DeepNorm alpha used in embedding residual; default uses provided encoder_layers
325
  self.deepnorm_alpha = (2. * encoder_layers) ** -0.25