schrum2 committed on
Commit
cc82bc6
·
verified ·
1 Parent(s): a09cfc1

Going to move this file

Browse files
Files changed (1) hide show
  1. text_model.py +0 -206
text_model.py DELETED
@@ -1,206 +0,0 @@
1
- import argparse
2
- from xml.parsers.expat import model
3
- import torch
4
- import torch.nn as nn
5
- import math
6
- import os
7
- import json
8
- from safetensors.torch import save_file, load_file
9
- from tokenizer import Tokenizer
10
-
11
def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
    """Encode captions into text-encoder embeddings for (optionally guided) generation.

    The result always contains unconditional ("") embeddings for `batch_size`
    rows; positive-caption embeddings are appended after them, and negative
    caption embeddings are prepended before them, giving the row order
    [negatives, unconditional, positives] along dim 0.
    """
    seq_len = text_encoder.max_seq_length

    def _embed(texts):
        # Tokenize/pad on `device`, then run the text encoder.
        ids = encode_token_captions(texts, tokenizer, seq_len, device=device)
        return text_encoder.get_embeddings(ids)

    # Unconditional embeddings are always present.
    parts = [_embed([""] * batch_size)]

    if captions is not None:
        parts.append(_embed(captions))

    if neg_captions is not None:
        parts.insert(0, _embed(neg_captions))

    combined = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
    return combined.to(device)
27
-
28
def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
    """Tokenize and pad each caption, returning a (len(captions), max_length) LongTensor on `device`."""
    rows = [
        torch.tensor(tokenizer.pad_sequence(tokenizer.encode(text), max_length), dtype=torch.long)
        for text in captions
    ]
    # All rows share the same padded length, so stacking yields the batch tensor.
    return torch.stack(rows, dim=0).to(device)
35
-
36
-
37
-
38
-
39
-
40
-
41
-
42
-
43
-
44
- # Transformer model for MLM training
45
-
46
class TransformerModel(nn.Module):
    """Transformer encoder for masked-language-model (MLM) training over captions.

    Token ids are embedded, combined with a fixed sinusoidal positional
    encoding, and passed through a stack of TransformerEncoder layers;
    `forward` additionally projects the encoder output back to vocabulary
    logits, while `get_embeddings` exposes the raw latent vectors.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
        """
        Args:
            vocab_size: number of tokens in the vocabulary.
            embedding_dim: token-embedding size (the transformer model dimension).
            hidden_dim: feed-forward width inside each encoder layer.
            tokenizer: optional project Tokenizer, saved/loaded alongside the weights.
            num_heads: attention heads per encoder layer.
            num_layers: number of stacked encoder layers.
            max_seq_length: longest sequence the positional encoding supports.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.tokenizer = tokenizer

    def create_positional_encoding(self, max_seq_length, embedding_dim):
        """Build a sinusoidal positional encoding of shape (1, max_seq_length, embedding_dim).

        Each position gets a unique sin/cos pattern: frequencies differ per
        dimension pair, and sin/cos bound the values to [-1, 1].
        """
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # Divisors create a geometric progression of frequencies across dimensions.
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_seq_length, embedding_dim)
        # Even dimensions use sin, odd dimensions use cos.
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)

    def get_embeddings(self, x):
        """Return the encoder's latent vectors for token-id batch `x` of shape (batch, seq)."""
        # Ensure positional encoding is on the same device as the input.
        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
        embedded = self.embedding(x) + pe
        return self.transformer(embedded)

    def forward(self, x):
        """Return vocabulary logits of shape (batch, seq, vocab_size) for token-id batch `x`."""
        transformer_out = self.get_embeddings(x)
        # Project latent vectors to vocabulary size.
        return self.fc(transformer_out)

    def save_pretrained(self, save_directory):
        """Save config (JSON), weights (safetensors), and tokenizer (if any) to `save_directory`."""
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_length": self.max_seq_length,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        # Save model weights
        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))

        # Save tokenizer if present
        if self.tokenizer is not None:
            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by `save_pretrained`."""
        with open(os.path.join(load_directory, "config.json")) as f:
            config = json.load(f)

        model = cls(**config)

        # Load weights
        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
        model.load_state_dict(state_dict)

        # Load tokenizer if available
        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer()
            tokenizer.load(tokenizer_path)
            model.tokenizer = tokenizer

        return model

    def print_architecture(self, inputs=None):
        """Render this model's architecture to 'mlm_architecture.pdf' via torchview.

        Bug fix: the previous version re-parsed command-line arguments and
        re-loaded a model from ``--model_path``, ignoring ``self`` entirely,
        so any normal method call crashed unless the script was launched with
        those CLI flags. It now visualizes ``self`` directly.

        Args:
            inputs: optional example input handed to torchview; when None, a
                dummy (1, max_seq_length) LongTensor on the model's device is used.
        """
        # Local import: torchview is an optional visualization-only dependency.
        from torchview import draw_graph

        if inputs is None:
            device = next(self.parameters()).device
            inputs = torch.zeros((1, self.max_seq_length), dtype=torch.long, device=device)

        graph = draw_graph(
            model=self,
            input_data=inputs,
            expand_nested=False,
            depth=1,
        )

        # Save plot; cleanup=False keeps the intermediate graphviz source file.
        filename = 'mlm_architecture'
        graph.visual_graph.render(filename, format='pdf', cleanup=False)

    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
        """Save a visualization of the model architecture as a PDF using torchview.

        A dummy LongTensor input is traced through the model; torchview writes
        a PNG next to `filename`, which is then converted to PDF via PIL when
        available. When a tokenizer is attached, the dummy input length is
        taken from two example captions; otherwise `input_length` is used.
        """
        try:
            from torchview import draw_graph
        except ImportError:
            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")

        device = next(self.parameters()).device

        # Previous version crashed when no tokenizer was attached and computed
        # an input_length from an encoded tensor only to immediately overwrite it.
        if self.tokenizer is not None:
            captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
            num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
            if num_tokens_list:
                input_length = max(num_tokens_list)
        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=device)

        # Draw the graph; torchview saves a PNG alongside the requested filename.
        draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True,
                   filename=filename.replace('.pdf', ''), directory=".", depth=2)
        png_file = filename.replace('.pdf', '.png')

        if not os.path.exists(png_file):
            print(f"Could not find PNG file to convert: {png_file}")
            return

        # Convert PNG to PDF
        try:
            from PIL import Image
            im = Image.open(png_file)
            im.save(filename, "PDF", resolution=100.0)
            # Bug fix: message previously printed a placeholder instead of the filename.
            print(f"Saved architecture PDF to {filename}")
            # Optionally, remove the PNG file
            os.remove(png_file)
        except ImportError:
            print(f"PIL not installed. Architecture saved as PNG: {png_file}")
        except Exception as e:
            print(f"Could not convert PNG to PDF: {e}")