Xuezha committed
Commit e53fb05 · verified · 1 Parent(s): d82b98e

Update modeling.py

Files changed (1)
  1. modeling.py +2 -8
modeling.py CHANGED
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from configure import RecombinationTransformerConfig
 
 class MaskedSelfAttentionLayer(nn.Module):
     def __init__(self, embed_dim, num_heads):
@@ -133,14 +134,7 @@ class RecombinationTransformerLayer(nn.Module):
 
         return x
 
-class RecombinationTransformerConfig(PretrainedConfig):
-    model_type = "RecombinationTransformer"
-    def __init__(self, embed_dim=1024, num_heads=8, num_layers=6, vocab_size=151643, **kwargs):
-        super().__init__(**kwargs)
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.num_layers = num_layers
-        self.vocab_size = vocab_size
+
 
 class RecombinationTransformerForCausalLM(PreTrainedModel):
     config_class = RecombinationTransformerConfig
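
The removed class presumably now lives in the configure module pulled in by the new import; keeping the configuration in its own file next to modeling.py is the usual layout for transformers custom-code repos. A minimal sketch of what configure.py would contain, assuming the class was moved over verbatim (the file's actual contents are not part of this diff):

# configure.py -- sketch only; assumes the config class removed above was moved here unchanged
from transformers import PretrainedConfig

class RecombinationTransformerConfig(PretrainedConfig):
    # model_type is the key transformers uses to associate configs with model classes
    model_type = "RecombinationTransformer"

    def __init__(self, embed_dim=1024, num_heads=8, num_layers=6, vocab_size=151643, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim    # width of token embeddings / hidden states
        self.num_heads = num_heads    # attention heads per MaskedSelfAttentionLayer
        self.num_layers = num_layers  # stacked RecombinationTransformerLayer blocks
        self.vocab_size = vocab_size  # output vocabulary for the causal LM head

With that file in place the new import resolves, and the model builds as before (a hypothetical smoke test, not taken from the repo):

from configure import RecombinationTransformerConfig
from modeling import RecombinationTransformerForCausalLM

config = RecombinationTransformerConfig(embed_dim=1024, num_heads=8, num_layers=6)
model = RecombinationTransformerForCausalLM(config)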