import torch
import torch.nn as nn
import math
import os
import json
from safetensors.torch import save_file, load_file
from tokenizer import Tokenizer

def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
    max_length = text_encoder.max_seq_length
    # Unconditional embeddings from empty captions, one per batch element
    empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
    embeddings = text_encoder.get_embeddings(empty_ids)

    if captions is not None:
        # Append embeddings for the positive captions after the unconditional ones
        caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
        caption_embeddings = text_encoder.get_embeddings(caption_ids)
        embeddings = torch.cat((embeddings, caption_embeddings), dim=0)

    if neg_captions is not None:
        # Prepend embeddings for the negative captions
        neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
        neg_embeddings = text_encoder.get_embeddings(neg_ids)
        embeddings = torch.cat((neg_embeddings, embeddings), dim=0)

    return embeddings.to(device)

def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
    """Tokenize each caption, pad it to max_length, and stack into a (batch, max_length) LongTensor."""
    caption_ids = []
    for caption in captions:
        tokens = tokenizer.encode(caption)
        caption_tokens = tokenizer.pad_sequence(tokens, max_length)
        caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
    return torch.cat(caption_ids, dim=0).to(device)
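
# Example usage (an illustrative sketch, not part of the original module): given a
# trained TransformerModel acting as `text_encoder` and its Tokenizer, embeddings for
# negative, unconditional (empty), and positive captions can be stacked as follows.
#
#   captions = ["full floor. two coins. one pipe."]
#   neg_captions = ["many enemies."]
#   embeddings = get_embeddings(len(captions), tokenizer, text_encoder,
#                               captions=captions, neg_captions=neg_captions)
#   # Row order matches the concatenation in get_embeddings: negative captions first,
#   # then the unconditional (empty-caption) rows, then the positive captions.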









# Transformer model for MLM training

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_seq_length = max_seq_length

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.tokenizer = tokenizer

    def create_positional_encoding(self, max_seq_length, embedding_dim):
        # Sinusoidal positional encoding: each position in the sequence gets a unique pattern of values.
        # Different frequencies make the patterns distinct, while sin/cos keep every value bounded in [-1, 1].
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # Creates a set of divisors that create different frequencies
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_seq_length, embedding_dim)
        # Even dimensions use sin, odd dimensions use cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)
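    # For reference, this matches the standard sinusoidal formulation:
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # The returned tensor has shape (1, max_seq_length, embedding_dim) so it broadcasts
    # over the batch dimension when added to the token embeddings in get_embeddings.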

    def get_embeddings(self, x):
        """ This gets the actual latent embedding vectors """
        # Ensure positional encoding is on the same device as input
        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
        # Embed input and add positional encoding
        embedded = self.embedding(x) + pe
        return self.transformer(embedded)

    def forward(self, x):
        """ This gets the token within the vocabulary """
        transformer_out = self.get_embeddings(x)
        # Project to vocabulary size
        return self.fc(transformer_out)
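    # Illustrative MLM training step (a sketch, not part of the original module):
    # masked_ids are token ids with some positions replaced by a mask token, and
    # target_ids hold the original token at masked positions and -100 elsewhere, so
    # cross_entropy (default ignore_index=-100) only scores the masked positions.
    # Shapes assume (batch, seq_len) long tensors.
    #
    #   logits = model(masked_ids)                              # (batch, seq_len, vocab_size)
    #   loss = nn.functional.cross_entropy(
    #       logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
    #   loss.backward()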

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "max_seq_length": self.max_seq_length,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        # Save model weights
        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))

        # Save tokenizer if present
        if self.tokenizer is not None:
            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))

    @classmethod
    def from_pretrained(cls, load_directory):
        with open(os.path.join(load_directory, "config.json")) as f:
            config = json.load(f)

        model = cls(**config)

        # Load weights
        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
        model.load_state_dict(state_dict)

        # Load tokenizer if available
        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer()
            tokenizer.load(tokenizer_path)
            model.tokenizer = tokenizer

        return model
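
    # Example round trip (illustrative; the directory name is a placeholder):
    #
    #   model.save_pretrained("checkpoints/mlm")
    #   reloaded = TransformerModel.from_pretrained("checkpoints/mlm")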
    
    def print_architecture(self, inputs=None):
        """Render a torchview diagram of this model and save it as a PDF."""
        from torchview import draw_graph

        # Fall back to a dummy batch of token ids if no example input is provided
        if inputs is None:
            device = next(self.parameters()).device
            inputs = torch.zeros((1, self.max_seq_length), dtype=torch.long, device=device)

        graph = draw_graph(
            model=self,
            input_data=inputs,
            expand_nested=False,
            depth=1
        )

        # Save plot
        filename = 'mlm_architecture'
        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup=False keeps intermediate files

    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
        """Save a visualization of the model architecture as a PDF using torchview."""
        try:
            from torchview import draw_graph
        except ImportError:
            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")

        device = next(self.parameters()).device
        # Build a dummy token-id input of the right dtype and shape. If a tokenizer is
        # attached, size it to the longest of two sample captions; otherwise fall back
        # to the input_length argument.
        if self.tokenizer is not None:
            captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
            input_length = max(len(self.tokenizer.encode(c)) for c in captions)
        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=device)

        # Draw the graph and save as PNG
        graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True,
                           filename=filename.replace('.pdf', ''), directory=".", depth=2)
        png_file = filename.replace('.pdf', '.png')
        # Convert the PNG to PDF if Pillow is available
        if os.path.exists(png_file):
            try:
                from PIL import Image
                im = Image.open(png_file)
                im.save(filename, "PDF", resolution=100.0)
                print(f"Saved architecture PDF to {filename}")
                # Optionally, remove the intermediate PNG file
                os.remove(png_file)
            except ImportError:
                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
            except Exception as e:
                print(f"Could not convert PNG to PDF: {e}")
        else:
            print(f"Could not find PNG file to convert: {png_file}")