schrum2 committed on
Commit
1ef07aa
·
1 Parent(s): ff12ff1

Removing unnecessary files, reverting back to originals

Browse files
Files changed (4) hide show
  1. model_index.json +17 -7
  2. models/text_model.py +0 -206
  3. tokenizer.py +0 -147
  4. util/common_settings.py +0 -18
model_index.json CHANGED
@@ -1,10 +1,20 @@
1
  {
2
  "_class_name": "TextConditionalDDPMPipeline",
3
  "_diffusers_version": "0.32.2",
4
- "components": {
5
- "unet": { "type": "UNet2DConditionModel", "subfolder": "unet" },
6
- "text_encoder": { "type": "models.text_model.TransformerModel", "subfolder": "text_encoder" },
7
- "tokenizer": { "type": "Tokenizer", "file": "tokenizer.py" },
8
- "scheduler": { "type": "DDPMScheduler", "subfolder": "scheduler" }
9
- }
10
- }
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
  "_class_name": "TextConditionalDDPMPipeline",
3
  "_diffusers_version": "0.32.2",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "DDPMScheduler"
7
+ ],
8
+ "text_encoder": [
9
+ "models.text_model",
10
+ "TransformerModel"
11
+ ],
12
+ "tokenizer": [
13
+ "tokenizer",
14
+ "Tokenizer"
15
+ ],
16
+ "unet": [
17
+ "diffusers",
18
+ "UNet2DConditionModel"
19
+ ]
20
+ }
models/text_model.py DELETED
@@ -1,206 +0,0 @@
1
- import argparse
2
- from xml.parsers.expat import model
3
- import torch
4
- import torch.nn as nn
5
- import math
6
- import os
7
- import json
8
- from safetensors.torch import save_file, load_file
9
- from tokenizer import Tokenizer
10
-
11
- def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
12
- max_length = text_encoder.max_seq_length
13
- empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
14
- embeddings = text_encoder.get_embeddings(empty_ids)
15
-
16
- if(captions is not None):
17
- caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
18
- caption_embeddings = text_encoder.get_embeddings(caption_ids)
19
- embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
20
-
21
- if(neg_captions is not None):
22
- neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
23
- neg_embeddings = text_encoder.get_embeddings(neg_ids)
24
- embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
25
-
26
- return embeddings.to(device)
27
-
28
- def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
29
- caption_ids = []
30
- for caption in captions:
31
- tokens = tokenizer.encode(caption)
32
- caption_tokens = tokenizer.pad_sequence(tokens, max_length)
33
- caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
34
- return torch.cat(caption_ids, dim=0).to(device)
35
-
36
-
37
-
38
-
39
-
40
-
41
-
42
-
43
-
44
- # Transformer model for MLM training
45
-
46
- class TransformerModel(nn.Module):
47
- def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
48
- super().__init__()
49
- self.embedding_dim = embedding_dim
50
- self.vocab_size = vocab_size
51
- self.hidden_dim = hidden_dim
52
- self.num_heads = num_heads
53
- self.num_layers = num_layers
54
- self.max_seq_length = max_seq_length
55
-
56
- self.embedding = nn.Embedding(vocab_size, embedding_dim)
57
- self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)
58
-
59
- encoder_layers = nn.TransformerEncoderLayer(
60
- d_model=embedding_dim,
61
- nhead=num_heads,
62
- dim_feedforward=hidden_dim,
63
- batch_first=True
64
- )
65
- self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
66
- self.fc = nn.Linear(embedding_dim, vocab_size)
67
-
68
- self.tokenizer = tokenizer
69
-
70
- def create_positional_encoding(self, max_seq_length, embedding_dim):
71
- # The implementation uses a sinusoidal positional encoding, which creates a unique pattern for each position in the sequence.
72
- # The frequencies create unique values, the sin/cos bounds values
73
- position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
74
- # Creates a set of divisors that create different frequencies
75
- div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
76
- pe = torch.zeros(max_seq_length, embedding_dim)
77
- # Even dimensions use sin, odd dimensions use cos
78
- pe[:, 0::2] = torch.sin(position * div_term)
79
- pe[:, 1::2] = torch.cos(position * div_term)
80
- return pe.unsqueeze(0)
81
-
82
- def get_embeddings(self, x):
83
- """ This gets the actual latent embedding vectors """
84
- # Ensure positional encoding is on the same device as input
85
- pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
86
- # Embed input and add positional encoding
87
- embedded = self.embedding(x) + pe
88
- return self.transformer(embedded)
89
-
90
- def forward(self, x):
91
- """ This gets the token within the vocabulary """
92
- transformer_out = self.get_embeddings(x)
93
- # Project to vocabulary size
94
- return self.fc(transformer_out)
95
-
96
- def save_pretrained(self, save_directory):
97
- os.makedirs(save_directory, exist_ok=True)
98
-
99
- config = {
100
- "vocab_size": self.vocab_size,
101
- "embedding_dim": self.embedding_dim,
102
- "hidden_dim": self.hidden_dim,
103
- "num_heads": self.num_heads,
104
- "num_layers": self.num_layers,
105
- "max_seq_length": self.max_seq_length,
106
- }
107
- with open(os.path.join(save_directory, "config.json"), "w") as f:
108
- json.dump(config, f)
109
-
110
- # Save model weights
111
- save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
112
-
113
- # Save tokenizer if present
114
- if self.tokenizer is not None:
115
- self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))
116
-
117
- @classmethod
118
- def from_pretrained(cls, load_directory):
119
- with open(os.path.join(load_directory, "config.json")) as f:
120
- config = json.load(f)
121
-
122
- model = cls(**config)
123
-
124
- # Load weights
125
- state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
126
- model.load_state_dict(state_dict)
127
-
128
- # Load tokenizer if available
129
- tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
130
- if os.path.exists(tokenizer_path):
131
- tokenizer = Tokenizer()
132
- tokenizer.load(tokenizer_path)
133
- model.tokenizer = tokenizer
134
-
135
- return model
136
-
137
- def print_architecture(self, inputs=None):
138
- parser = argparse.ArgumentParser()
139
- parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
140
- parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
141
- parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
142
- parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")
143
-
144
- parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
145
- args = parser.parse_args()
146
-
147
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
- model = TransformerModel.from_pretrained(args.model_path).to(device)
149
- print(f"Loaded model from {args.model_path}")
150
-
151
- import os
152
- import re
153
- import json
154
- import matplotlib.pyplot as plt
155
- from torchview import draw_graph
156
- import graphviz
157
-
158
- graph = draw_graph(
159
- model=model,
160
- input_data=inputs,
161
- expand_nested=False,
162
- #enable_output_shape=True,
163
- #roll_out="nested",
164
- depth=1
165
- )
166
-
167
- # Save plot
168
- filename = 'mlm_architecture'
169
- graph.visual_graph.render(filename, format='pdf', cleanup=False) # Cleanup removes intermediate files
170
- #graph.visual_graph.save('unet_architecture.dot')
171
-
172
- def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
173
- """Save a visualization of the model architecture as a PDF using torchview."""
174
- try:
175
- from torchview import draw_graph
176
- except ImportError:
177
- raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
178
- import torch
179
- import os
180
- # Create a dummy input of the correct type for the model
181
- captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
182
- tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
183
- input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length
184
-
185
- num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
186
- input_length = max(num_tokens_list) if num_tokens_list else input_length
187
- dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)
188
-
189
- # Draw the graph and save as PNG
190
- graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf',''), directory=".", depth=2)
191
- png_file = filename.replace('.pdf', '.png')
192
- # Convert PNG to PDF
193
- if os.path.exists(png_file):
194
- try:
195
- from PIL import Image
196
- im = Image.open(png_file)
197
- im.save(filename, "PDF", resolution=100.0)
198
- print(f"Saved architecture PDF to {filename}")
199
- # Optionally, remove the PNG file
200
- os.remove(png_file)
201
- except ImportError:
202
- print(f"PIL not installed. Architecture saved as PNG: {png_file}")
203
- except Exception as e:
204
- print(f"Could not convert PNG to PDF: {e}")
205
- else:
206
- print(f"Could not find PNG file to convert: {png_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tokenizer.py DELETED
@@ -1,147 +0,0 @@
1
- import json
2
- import re
3
- from collections import Counter
4
- import pickle
5
- import argparse
6
-
7
- class Tokenizer:
8
- def __init__(self):
9
- self.special_tokens = ["[PAD]", "[MASK]"]
10
- self.vocab = {}
11
- self.token_to_id = {}
12
- self.id_to_token = {}
13
-
14
- def tokenize(self, text):
15
- # Match words, numbers, periods, and commas as separate tokens
16
- tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
17
- # Restore MASK and PAD to all caps
18
- modified_list = []
19
- for s in tokens:
20
- modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
21
- modified_list.append(modified_s)
22
- return modified_list
23
-
24
- def pad_sequence(self, tokens, length):
25
- """Pads tokenized sequences to length with a padding token (assumed to be '[PAD]')."""
26
- if len(tokens) > length:
27
- raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
28
-
29
- pad_token = self.token_to_id["[PAD]"]
30
- return tokens + [pad_token] * (length - len(tokens))
31
-
32
- def build_vocab(self, dataset_path, min_freq=1):
33
- token_counter = Counter()
34
-
35
- with open(dataset_path, 'r') as f:
36
- data = json.load(f)
37
- for entry in data:
38
- caption = entry['caption']
39
- tokens = self.tokenize(caption)
40
- token_counter.update(tokens)
41
-
42
- # Keep tokens that meet the min frequency
43
- tokens = [tok for tok, count in token_counter.items() if count >= min_freq]
44
-
45
- # Ensure special tokens are always included
46
- all_tokens = self.special_tokens + sorted(tokens)
47
-
48
- # Build vocab dictionaries
49
- self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
50
- self.token_to_id = self.vocab
51
- self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
52
-
53
- print(f"Vocabulary size: {len(self.vocab)}")
54
-
55
- def encode(self, text):
56
- tokens = self.tokenize(text)
57
- encoded = []
58
- for tok in tokens:
59
- if tok not in self.token_to_id:
60
- raise ValueError(f"Unknown token encountered: {tok} in {text}")
61
- encoded.append(self.token_to_id[tok])
62
- return encoded
63
-
64
- def encode_batch(self, texts, pad_to_length=None):
65
- """
66
- Encode a batch of texts into token IDs with padding to ensure uniform length.
67
-
68
- Args:
69
- texts (list): A list of strings to encode
70
- pad_to_length (int, optional): Length to pad all sequences to. If None,
71
- will pad to the length of the longest sequence.
72
-
73
- Returns:
74
- list: A list of lists, where each inner list contains the token IDs for a text
75
- """
76
- # Get the padding token ID
77
- pad_token = self.token_to_id["[PAD]"]
78
-
79
- # First encode all texts
80
- encoded_texts = []
81
- for text in texts:
82
- try:
83
- encoded = self.encode(text)
84
- encoded_texts.append(encoded)
85
- except ValueError as e:
86
- raise ValueError(f"Error encoding text: {text}. {str(e)}")
87
-
88
- # Determine padding length
89
- if pad_to_length is None:
90
- pad_to_length = max(len(seq) for seq in encoded_texts)
91
-
92
- # Pad sequences to uniform length
93
- padded_texts = []
94
- for seq in encoded_texts:
95
- if len(seq) > pad_to_length:
96
- # Truncate if too long
97
- padded_texts.append(seq[:pad_to_length])
98
- else:
99
- # Pad if too short
100
- padding = [pad_token] * (pad_to_length - len(seq))
101
- padded_texts.append(seq + padding)
102
-
103
- return padded_texts
104
-
105
- def decode(self, token_ids):
106
- return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)
107
-
108
- def save(self, path):
109
- with open(path, 'wb') as f:
110
- pickle.dump({'vocab': self.vocab}, f)
111
-
112
- def load(self, path):
113
- with open(path, 'rb') as f:
114
- data = pickle.load(f)
115
- self.vocab = data['vocab']
116
- self.token_to_id = self.vocab
117
- self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
118
-
119
- def get_vocab(self):
120
- return sorted(self.vocab.keys())
121
-
122
- def get_vocab_size(self):
123
- return len(self.vocab)
124
-
125
- if __name__ == "__main__":
126
- tokenizer = Tokenizer()
127
-
128
- parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
129
- parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
130
- parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
131
- parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
132
-
133
- args = parser.parse_args()
134
-
135
- if args.action == "save":
136
- if not args.json_file:
137
- raise ValueError("The --json_file argument is required for the 'save' action.")
138
- tokenizer.build_vocab(args.json_file)
139
- tokenizer.save(args.pkl_file)
140
- elif args.action == "load":
141
- tokenizer.load(args.pkl_file)
142
-
143
- # Example usage
144
- #print(tokenizer.encode("floor with one gap. one enemy."))
145
- #print(tokenizer.get_vocab())
146
- #for id, token in tokenizer.id_to_token.items():
147
- # print(id,":",token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util/common_settings.py DELETED
@@ -1,18 +0,0 @@
1
-
2
- NUM_INFERENCE_STEPS = 30
3
- GUIDANCE_SCALE = 7.5
4
-
5
- MARIO_HEIGHT = 16
6
- MARIO_WIDTH = 16
7
-
8
- MARIO_TILE_PIXEL_DIM = 16
9
- MARIO_TILE_COUNT = 13
10
-
11
- LR_HEIGHT = 32
12
- LR_WIDTH = 32
13
-
14
- LR_TILE_PIXEL_DIM = 8
15
- LR_TILE_COUNT = 8
16
-
17
- MEGAMAN_HEIGHT = 14
18
- MEGAMAN_WIDTH = 16