schrum2 commited on
Commit
88ffa9a
·
verified ·
1 Parent(s): 19248d9

Hugging face really wants to look for this code here

Browse files
Files changed (1) hide show
  1. text_encoder/text_model.py +206 -0
text_encoder/text_model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from xml.parsers.expat import model
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+ import os
7
+ import json
8
+ from safetensors.torch import save_file, load_file
9
+ from tokenizer import Tokenizer
10
+
11
+ def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
12
+ max_length = text_encoder.max_seq_length
13
+ empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
14
+ embeddings = text_encoder.get_embeddings(empty_ids)
15
+
16
+ if(captions is not None):
17
+ caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
18
+ caption_embeddings = text_encoder.get_embeddings(caption_ids)
19
+ embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
20
+
21
+ if(neg_captions is not None):
22
+ neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
23
+ neg_embeddings = text_encoder.get_embeddings(neg_ids)
24
+ embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
25
+
26
+ return embeddings.to(device)
27
+
28
+ def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
29
+ caption_ids = []
30
+ for caption in captions:
31
+ tokens = tokenizer.encode(caption)
32
+ caption_tokens = tokenizer.pad_sequence(tokens, max_length)
33
+ caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
34
+ return torch.cat(caption_ids, dim=0).to(device)
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+ # Transformer model for MLM training
45
+
46
+ class TransformerModel(nn.Module):
47
+ def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
48
+ super().__init__()
49
+ self.embedding_dim = embedding_dim
50
+ self.vocab_size = vocab_size
51
+ self.hidden_dim = hidden_dim
52
+ self.num_heads = num_heads
53
+ self.num_layers = num_layers
54
+ self.max_seq_length = max_seq_length
55
+
56
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
57
+ self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)
58
+
59
+ encoder_layers = nn.TransformerEncoderLayer(
60
+ d_model=embedding_dim,
61
+ nhead=num_heads,
62
+ dim_feedforward=hidden_dim,
63
+ batch_first=True
64
+ )
65
+ self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
66
+ self.fc = nn.Linear(embedding_dim, vocab_size)
67
+
68
+ self.tokenizer = tokenizer
69
+
70
+ def create_positional_encoding(self, max_seq_length, embedding_dim):
71
+ # The implementation uses a sinusoidal positional encoding, which creates a unique pattern for each position in the sequence.
72
+ # The frequencies create unique values, the sin/cos bounds values
73
+ position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
74
+ # Creates a set of divisors that create different frequencies
75
+ div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
76
+ pe = torch.zeros(max_seq_length, embedding_dim)
77
+ # Even dimensions use sin, odd dimensions use cos
78
+ pe[:, 0::2] = torch.sin(position * div_term)
79
+ pe[:, 1::2] = torch.cos(position * div_term)
80
+ return pe.unsqueeze(0)
81
+
82
+ def get_embeddings(self, x):
83
+ """ This gets the actual latent embedding vectors """
84
+ # Ensure positional encoding is on the same device as input
85
+ pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
86
+ # Embed input and add positional encoding
87
+ embedded = self.embedding(x) + pe
88
+ return self.transformer(embedded)
89
+
90
+ def forward(self, x):
91
+ """ This gets the token within the vocabulary """
92
+ transformer_out = self.get_embeddings(x)
93
+ # Project to vocabulary size
94
+ return self.fc(transformer_out)
95
+
96
+ def save_pretrained(self, save_directory):
97
+ os.makedirs(save_directory, exist_ok=True)
98
+
99
+ config = {
100
+ "vocab_size": self.vocab_size,
101
+ "embedding_dim": self.embedding_dim,
102
+ "hidden_dim": self.hidden_dim,
103
+ "num_heads": self.num_heads,
104
+ "num_layers": self.num_layers,
105
+ "max_seq_length": self.max_seq_length,
106
+ }
107
+ with open(os.path.join(save_directory, "config.json"), "w") as f:
108
+ json.dump(config, f)
109
+
110
+ # Save model weights
111
+ save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
112
+
113
+ # Save tokenizer if present
114
+ if self.tokenizer is not None:
115
+ self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))
116
+
117
+ @classmethod
118
+ def from_pretrained(cls, load_directory):
119
+ with open(os.path.join(load_directory, "config.json")) as f:
120
+ config = json.load(f)
121
+
122
+ model = cls(**config)
123
+
124
+ # Load weights
125
+ state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
126
+ model.load_state_dict(state_dict)
127
+
128
+ # Load tokenizer if available
129
+ tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
130
+ if os.path.exists(tokenizer_path):
131
+ tokenizer = Tokenizer()
132
+ tokenizer.load(tokenizer_path)
133
+ model.tokenizer = tokenizer
134
+
135
+ return model
136
+
137
+ def print_architecture(self, inputs=None):
138
+ parser = argparse.ArgumentParser()
139
+ parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
140
+ parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
141
+ parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
142
+ parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")
143
+
144
+ parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
145
+ args = parser.parse_args()
146
+
147
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
+ model = TransformerModel.from_pretrained(args.model_path).to(device)
149
+ print(f"Loaded model from {args.model_path}")
150
+
151
+ import os
152
+ import re
153
+ import json
154
+ import matplotlib.pyplot as plt
155
+ from torchview import draw_graph
156
+ import graphviz
157
+
158
+ graph = draw_graph(
159
+ model=model,
160
+ input_data=inputs,
161
+ expand_nested=False,
162
+ #enable_output_shape=True,
163
+ #roll_out="nested",
164
+ depth=1
165
+ )
166
+
167
+ # Save plot
168
+ filename = 'mlm_architecture'
169
+ graph.visual_graph.render(filename, format='pdf', cleanup=False) # Cleanup removes intermediate files
170
+ #graph.visual_graph.save('unet_architecture.dot')
171
+
172
+ def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
173
+ """Save a visualization of the model architecture as a PDF using torchview."""
174
+ try:
175
+ from torchview import draw_graph
176
+ except ImportError:
177
+ raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
178
+ import torch
179
+ import os
180
+ # Create a dummy input of the correct type for the model
181
+ captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
182
+ tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
183
+ input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length
184
+
185
+ num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
186
+ input_length = max(num_tokens_list) if num_tokens_list else input_length
187
+ dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)
188
+
189
+ # Draw the graph and save as PNG
190
+ graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf',''), directory=".", depth=2)
191
+ png_file = filename.replace('.pdf', '.png')
192
+ # Convert PNG to PDF
193
+ if os.path.exists(png_file):
194
+ try:
195
+ from PIL import Image
196
+ im = Image.open(png_file)
197
+ im.save(filename, "PDF", resolution=100.0)
198
+ print(f"Saved architecture PDF to {filename}")
199
+ # Optionally, remove the PNG file
200
+ os.remove(png_file)
201
+ except ImportError:
202
+ print(f"PIL not installed. Architecture saved as PNG: {png_file}")
203
+ except Exception as e:
204
+ print(f"Could not convert PNG to PDF: {e}")
205
+ else:
206
+ print(f"Could not find PNG file to convert: {png_file}")