ghg

modelLM.py CHANGED (+17 -7)
@@ -3,12 +3,17 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers.modeling_utils import PreTrainedModel
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
 # Define your custom language model class
 class OBILanguageModel(PreTrainedModel):
     def __init__(self, config):
         super(OBILanguageModel,self).__init__(config)
         self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size)  # Use length of SentencePiece vocab
         self.position_embedding_table = nn.Embedding(config.block_size, config.hidden_size)
+
+
         self.transformer = nn.Transformer(
             d_model=config.hidden_size,
             nhead=config.num_attention_heads,
@@ -21,16 +26,21 @@ class OBILanguageModel(PreTrainedModel):
         self.ln1 = nn.LayerNorm(config.hidden_size)
         self.ln2 = nn.LayerNorm(config.hidden_size)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)  # Use length of SentencePiece vocab
+
+
 
-    def forward(self, idx, attention_mask=None):
+    def forward(self, idx, targets=None):
         tok_emb = self.token_embedding_table(idx)
-        pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device=device))
+        # pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device=device))
+        pos_emb = None  # Initialize pos_emb to None
+        try:
+            pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
+        except IndexError as e:
+            # Print relevant information for debugging
+            print(f"IndexError: {e}")
+            print(f"idx.size(1): {idx.size(1)}")
+            print(f"Positional embedding table shape: {self.position_embedding_table.weight.shape}")
         x = tok_emb + pos_emb
-
-        # Assuming you need to add attention_mask here
-        if attention_mask is not None:
-            x *= attention_mask.unsqueeze(-1)
-
         x = self.transformer(x, x)
         x = self.ln1(x)
         x = self.ln2(x)
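The forward pass above indexes the positional table with torch.arange(idx.size(1)), so the IndexError that the new try/except reports fires exactly when a batch is longer than config.block_size. A minimal standalone sketch of that failure mode (tok_table, pos_table, and the sizes are illustrative stand-ins, not code from this repo):

import torch
import torch.nn as nn

block_size, vocab_size, hidden_size = 8, 100, 16
tok_table = nn.Embedding(vocab_size, hidden_size)  # stands in for token_embedding_table
pos_table = nn.Embedding(block_size, hidden_size)  # stands in for position_embedding_table

# In-range sequence: positions 0..7 all exist in the 8-row positional table.
idx = torch.randint(0, vocab_size, (1, block_size))
tok_emb = tok_table(idx)                                           # (1, 8, 16)
pos_emb = pos_table(torch.arange(idx.size(1), device=idx.device))  # (8, 16)
x = tok_emb + pos_emb                                              # broadcasts to (1, 8, 16)

# One token too many: position 8 does not exist, so the lookup raises.
too_long = torch.randint(0, vocab_size, (1, block_size + 1))
try:
    pos_table(torch.arange(too_long.size(1)))
except IndexError as e:
    print(f"IndexError: {e}")  # "index out of range in self" on CPU

One caveat with hardcoding device='cpu' in the lookup: pos_emb then stays on the CPU, so tok_emb + pos_emb raises a device-mismatch RuntimeError as soon as idx (and hence tok_emb) lives on CUDA. Deriving the device from the input, as idx.device does above, keeps both tensors together (though on CUDA an out-of-range index surfaces as a device-side assert rather than IndexError).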