schrum2 committed on
Commit
ef6bba8
·
verified ·
1 Parent(s): 4c0a730

Deleting directories, moving files into root

Browse files
util/common_settings.py DELETED
@@ -1,18 +0,0 @@
1
-
2
- NUM_INFERENCE_STEPS = 30
3
- GUIDANCE_SCALE = 7.5
4
-
5
- MARIO_HEIGHT = 16
6
- MARIO_WIDTH = 16
7
-
8
- MARIO_TILE_PIXEL_DIM = 16
9
- MARIO_TILE_COUNT = 13
10
-
11
- LR_HEIGHT = 32
12
- LR_WIDTH = 32
13
-
14
- LR_TILE_PIXEL_DIM = 8
15
- LR_TILE_COUNT = 8
16
-
17
- MEGAMAN_HEIGHT = 14
18
- MEGAMAN_WIDTH = 16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util/naming_conventions.py DELETED
@@ -1,29 +0,0 @@
1
- model_name_map = [
2
- ("Mar1and2-conditional-regular", "MLM-regular"),
3
- ("Mar1and2-conditional-absence", "MLM-absence"),
4
- ("Mar1and2-conditional-negative", "MLM-negative"),
5
- ("Mar1and2-conditional-MiniLM-regular", "MiniLM-single-regular"),
6
- ("Mar1and2-conditional-MiniLM-absence", "MiniLM-single-absence"),
7
- ("Mar1and2-conditional-MiniLM-negative", "MiniLM-single-negative"),
8
- ("Mar1and2-conditional-MiniLMsplit-regular", "MiniLM-multiple-regular"),
9
- ("Mar1and2-conditional-MiniLMsplit-absence", "MiniLM-multiple-absence"),
10
- ("Mar1and2-conditional-MiniLMsplit-negative", "MiniLM-multiple-negative"),
11
- ("Mar1and2-conditional-GTE-regular", "GTE-single-regular"),
12
- ("Mar1and2-conditional-GTE-absence", "GTE-single-absence"),
13
- ("Mar1and2-conditional-GTE-negative", "GTE-single-negative"),
14
- ("Mar1and2-conditional-GTEsplit-regular", "GTE-multiple-regular"),
15
- ("Mar1and2-conditional-GTEsplit-absence", "GTE-multiple-absence"),
16
- ("Mar1and2-conditional-GTEsplit-negative", "GTE-multiple-negative"),
17
- ("Mar1and2-fdm-MiniLM-regular", "FDM-MiniLM-regular"),
18
- ("Mar1and2-fdm-MiniLM-absence", "FDM-MiniLM-absence"),
19
- ("Mar1and2-fdm-GTE-regular", "FDM-GTE-regular"),
20
- ("Mar1and2-fdm-GTE-absence", "FDM-GTE-absence"),
21
- ("Mar1and2-wgan", "WGAN"),
22
- ("Mar1and2-unconditional", "Unconditional"),
23
- ("MarioGPT_metrics", "MarioGPT"),
24
- ]
25
-
26
- def get_model_name_map_and_order():
27
- mapping = dict(model_name_map)
28
- order = [v for k, v in model_name_map]
29
- return mapping, order
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util/plotter.py DELETED
@@ -1,173 +0,0 @@
1
- # Track changes in loss and learning rate during execution
2
- import argparse
3
- import matplotlib
4
- import matplotlib.pyplot as plt
5
- import os
6
- import time
7
- import json
8
- import tempfile
9
- import shutil
10
- from pathlib import Path
11
-
12
-
13
- def parse_args():
14
- parser = argparse.ArgumentParser(description="Train a text-conditional diffusion model for tile-based level generation")
15
-
16
- # Dataset args
17
- parser.add_argument("--log_file", type=str, default=None, help="The the filepath of the file to get the data from")
18
- parser.add_argument("--left_key", type=str, default=None, help="The key for the left y-axis")
19
- parser.add_argument("--right_key", type=str, default=None, help="The key for the right y-axis")
20
- parser.add_argument("--left_label", type=str, default=None, help="The label for the left y-axis")
21
- parser.add_argument("--right_label", type=str, default=None, help="The label for the right y-axis")
22
- parser.add_argument("--output_png", type=str, default="output.png", help="The output png file")
23
- parser.add_argument("--update_interval", type=int, default=1.0, help="The update inteval in epochs")
24
- parser.add_argument("--start_point", type=int, default=None, help="The start point for the plot")
25
-
26
- return parser.parse_args()
27
-
28
-
29
- def main():
30
- args = parse_args()
31
-
32
- log_file = args.log_file
33
- left_key = args.left_key
34
- right_key = args.right_key
35
- left_label = args.left_label
36
- right_label = args.right_label
37
- output_png = args.output_png
38
- update_interval = args.update_interval
39
- start_point = args.start_point
40
-
41
- general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=update_interval, startPoint=start_point)
42
-
43
-
44
- def general_update_plot(log_file, left_key, right_key, left_label, right_label, output_png, update_interval=1.0, startPoint=None):
45
- log_dir = os.path.dirname(log_file)
46
-
47
- # Create figure here and ensure it's closed
48
- fig = plt.figure(figsize=(10, 6))
49
- ax = fig.add_subplot(111)
50
-
51
- try:
52
- if os.path.exists(log_file):
53
- with open(log_file, 'r') as f:
54
- data = [json.loads(line) for line in f if line.strip()]
55
-
56
- if not data:
57
- return
58
-
59
- if startPoint is not None:
60
- data = [entry for entry in data if entry.get('epoch', 0) >= startPoint]
61
-
62
- if not data:
63
- return
64
-
65
- epochs = [entry.get('epoch', 0) for entry in data]
66
- left = [entry.get(left_key, 0) for entry in data]
67
-
68
- # For right axis (e.g., lr), only include points where right_key exists
69
- right_points = [(entry.get('epoch', 0), entry.get(right_key))
70
- for entry in data if right_key in entry]
71
- if right_points:
72
- right_epochs, right_values = zip(*right_points)
73
- else:
74
- right_epochs, right_values = [], []
75
-
76
- # Clear axis
77
- ax.clear()
78
-
79
- # Plot both metrics on the same axis
80
- ax.plot(epochs, left, 'b-', label=left_label)
81
- if right_epochs:
82
- ax.plot(right_epochs, right_values, 'r-', label=right_label)
83
-
84
- ax.set_xlabel('Epoch')
85
- ax.set_ylabel(left_label) # "Loss" as y-axis label
86
- ax.set_title('Training Progress')
87
- ax.legend(loc='upper left')
88
- #Limit x-axis to startPoint if provided
89
- if startPoint is not None:
90
- ax.set_xlim(left=startPoint)
91
- fig.tight_layout()
92
-
93
- # Use the stored base directory instead of getting it from log_file
94
- if os.path.isabs(output_png) or os.path.dirname(output_png):
95
- output_path = output_png
96
- else:
97
- output_path = os.path.join(log_dir, output_png)
98
-
99
- save_figure_safely(fig, output_path)
100
- finally:
101
- plt.close(fig) # Ensure figure is closed even if an error occurs
102
-
103
- def save_figure_safely(fig, output_path):
104
- """Save figure to a temporary file first, then move it to the final location"""
105
- output_path = str(Path(output_path)) # Convert to string path
106
-
107
- # Create temporary file with .png extension
108
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
109
- tmp_path = tmp_file.name
110
-
111
- try:
112
- # Save to temporary file
113
- fig.savefig(tmp_path)
114
-
115
- # Create output directory if it doesn't exist
116
- os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
117
-
118
- # Try to move the file to final destination
119
- # If move fails, try to copy and then delete
120
- try:
121
- shutil.move(tmp_path, output_path)
122
- except OSError:
123
- shutil.copy2(tmp_path, output_path)
124
- os.unlink(tmp_path)
125
- except Exception as e:
126
- # Clean up temporary file if anything goes wrong
127
- if os.path.exists(tmp_path):
128
- os.unlink(tmp_path)
129
- raise e
130
-
131
- class Plotter:
132
- def __init__(self, log_file, update_interval=1.0, left_key='loss', right_key='lr',
133
- left_label='Loss', right_label='Learning Rate', output_png='training_progress.png'):
134
- self.log_dir = os.path.dirname(log_file)
135
- self.log_file = log_file
136
- self.update_interval = update_interval
137
- self.running = True
138
- self.output_png = output_png
139
- self.left_key = left_key
140
- self.right_key = right_key
141
- self.left_label = left_label
142
- self.right_label = right_label
143
-
144
- matplotlib.use('Agg')
145
-
146
- def __enter__(self):
147
- return self
148
-
149
- def __exit__(self, exc_type, exc_val, exc_tb):
150
- self.stop_plotting()
151
-
152
- def __del__(self):
153
- self.stop_plotting()
154
-
155
- def update_plot(self):
156
- general_update_plot(self.log_file, self.left_key, self.right_key,
157
- self.left_label, self.right_label, self.output_png,
158
- update_interval=self.update_interval)
159
-
160
- def start_plotting(self):
161
- print("Starting plotting in background")
162
- while self.running:
163
- self.update_plot()
164
- time.sleep(self.update_interval)
165
-
166
- def stop_plotting(self):
167
- if hasattr(self, 'running'): # Check if already stopped
168
- self.running = False
169
- self.update_plot()
170
- print("Plotting stopped")
171
-
172
- if __name__ == "__main__":
173
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util/sampler.py DELETED
@@ -1,473 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import List, Optional, Tuple, Union
5
-
6
- import os
7
- import subprocess
8
- import tempfile
9
-
10
- import numpy as np
11
- import torch
12
- from PIL.Image import Image
13
- from tqdm import tqdm
14
- from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
15
-
16
-
17
- from mario_gpt.lm.base import BaseMarioLM
18
- from mario_gpt.prompter import Prompter
19
- from mario_gpt.simulator import Simulator
20
- from mario_gpt.utils import (
21
- convert_level_to_png,
22
- load_level,
23
- save_level,
24
- trim_level,
25
- view_level,
26
- )
27
-
28
- def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
29
- """
30
- Convert JSON scene files from a list of lists of ints
31
- to a list of ASCII strings using id_to_char mapping.
32
- If shorten is True, only the last 15 rows are kept.
33
- Args:
34
- scene: List[List[int]] - 2D array of tile IDs
35
- id_to_char: Dict[int, str] - mapping from tile ID to ASCII character
36
- shorten: bool - If True, will shorten the output to only include the first 15 rows
37
- so A* Mario (for SNES graphics) to run without glitching
38
- Returns:
39
- List[str]: List of strings, each representing a row in ASCII
40
- """
41
- if shorten and len(scene) > 15:
42
- scene = scene[-15:] # Keep only the last 15 rows
43
- return ["".join(id_to_char[num] for num in row) for row in scene]
44
-
45
- @dataclass
46
- class SampleOutput:
47
- level: Optional[List[str]]
48
- prompt: Optional[str] = None
49
- img: Optional[Image] = None
50
- sample_predictions_str: Optional[List[str]] = None
51
- sample_predictions_img: Optional[Image] = None
52
- level_tensor: Optional[torch.Tensor] = None
53
- sample_predictions_tensor: Optional[torch.Tensor] = None
54
- # Uses MarioEval graphics for rendering levels when True
55
- use_snes_graphics: bool = False
56
-
57
- @classmethod
58
- def create(
59
- cls,
60
- level_tensor: torch.Tensor,
61
- sample_predictions_tensor: torch.Tensor,
62
- tokenizer,
63
- prompter: Optional[Prompter] = None,
64
- ) -> SampleOutput:
65
- # batch = 1
66
- level = None
67
- img = None
68
-
69
- try:
70
- level = view_level(level_tensor, tokenizer)
71
- img = convert_level_to_png(level)[0]
72
- except Exception as e:
73
- print(
74
- f"Failed to generate string or image representation for full level! Got error {e}"
75
- )
76
- level = None
77
- img = None
78
- try:
79
- sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
80
- sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
81
- except Exception as e:
82
- print(
83
- f"Failed to generate string or image representation for sampled predictions! Got error {e}"
84
- )
85
- sample_predictions_str = None
86
- sample_predictions_img = None
87
-
88
- prompt = None
89
- if prompter is not None:
90
- prompt = prompter(level_tensor)[0]
91
-
92
- return SampleOutput(
93
- level,
94
- prompt,
95
- img,
96
- sample_predictions_str,
97
- sample_predictions_img,
98
- level_tensor,
99
- sample_predictions_tensor,
100
- )
101
-
102
- @classmethod
103
- def from_level_predictions(
104
- cls,
105
- level: torch.Tensor,
106
- sample_predictions: torch.Tensor,
107
- tokenizer,
108
- prompter: Optional[Prompter] = None,
109
- ) -> Union[SampleOutput, List[SampleOutput]]:
110
- level_tensor = trim_level(level).squeeze().detach().cpu()
111
- sample_predictions_tensor = (
112
- trim_level(sample_predictions).squeeze().detach().cpu()
113
- )
114
-
115
- if len(level_tensor.shape) == 1:
116
- return SampleOutput.create(
117
- level_tensor, sample_predictions_tensor, tokenizer, prompter
118
- )
119
-
120
- out = []
121
- for _level_tensor, _sample_predictions_tensor in zip(
122
- level_tensor, sample_predictions_tensor
123
- ):
124
- sample_output = SampleOutput.create(
125
- _level_tensor, _sample_predictions_tensor, tokenizer, prompter
126
- )
127
- out.append(sample_output)
128
- return out
129
-
130
- def save(self, filename: str) -> str:
131
- save_level(self.level, filename)
132
-
133
- @classmethod
134
- def load(cls, filename: str) -> SampleOutput:
135
- level = load_level(filename)
136
- return SampleOutput(level=level)
137
-
138
- def play(self, game="mario", level_idx=None, dataset_path=None):
139
- """
140
- Play the level using the specified game engine.
141
- game: "mario" (default) or "loderunner"
142
- """
143
- if game == "loderunner":
144
- import tempfile, json
145
- # Convert self.level (list of strings) to Lode Runner JSON format
146
- scene = [[c for c in row] for row in self.level]
147
- lr_json = [{
148
- "scene": scene,
149
- "caption": ""
150
- }]
151
- with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
152
- json.dump(lr_json, tmp)
153
- tmp_path = tmp.name
154
- import sys, os
155
- #sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
156
- from LodeRunner.loderunner import main
157
- tmp_path = tmp_path if dataset_path is None else dataset_path
158
- print(f"Playing Lode Runner level interactively -- {tmp_path}!")
159
- main.play_lr_level(tmp_path, level_index=level_idx if level_idx is not None else 1)
160
- else:
161
- if self.use_snes_graphics:
162
- simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
163
- else:
164
- simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
165
- simulator.interactive()
166
-
167
- def run_astar(self, render=True):
168
- if self.use_snes_graphics:
169
- simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
170
- else:
171
- simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
172
- return simulator.astar(render)
173
-
174
- class CustomSimulator:
175
- """
176
- The classic Mario simulator used by MarioGPT is generally,
177
- better, but it doesn't return any information about
178
- Mario's performance. The main point of this simulator
179
- is that information about the performance of the agent
180
- is printed to the console (though I still need a way
181
- to caption and return that information)
182
- """
183
-
184
- def __init__(self, level, jar_path="MarioEval.jar"):
185
- while len(level) > 15:
186
- level.pop(0)
187
- # For some reason, my older A* agent
188
- # crashes on Mario levels with 16 rows or more
189
-
190
- self.level = level
191
- self.jar_path = jar_path
192
-
193
- def interactive(self):
194
- t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
195
- save_level(self.level, t.name)
196
- print(f"Playing level interactively -- {t.name}!")
197
- _ = subprocess.run(
198
- ["java", "-jar", self.jar_path, "human", t.name, "human"],
199
- stdout=subprocess.PIPE,
200
- stderr=subprocess.PIPE,
201
- )
202
- t.close()
203
- os.unlink(t.name)
204
-
205
- def astar(self, render: bool = True):
206
- t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
207
- save_level(self.level, t.name)
208
- print(f"Running Astar agent on level! -- {t.name}")
209
- render_str = "human" if render else "norender"
210
- result = subprocess.run(
211
- ["java", "-jar", self.jar_path, "astar", t.name, render_str],
212
- stdout=subprocess.PIPE,
213
- stderr=subprocess.PIPE,
214
- )
215
- t.close()
216
- os.unlink(t.name)
217
- # Combine stdout and stderr, decode to string, and return
218
- output = result.stdout.decode("utf-8") + result.stderr.decode("utf-8")
219
- return output
220
-
221
- def save_level(level: List[str], filename: str):
222
- concatenated = "\n".join(level)
223
- with open(filename, "w") as f:
224
- f.write(concatenated)
225
- return filename
226
-
227
- class GPTSampler:
228
- def __init__(
229
- self,
230
- mario_lm: BaseMarioLM,
231
- temperature: float = 2.0,
232
- top_k: int = 16,
233
- context_len: int = 700,
234
- use_tqdm: bool = False,
235
- use_argmax: bool = False,
236
- ):
237
- self.mario_lm = mario_lm
238
- self.temperature = temperature
239
- self.top_k = top_k
240
- self.context_len = context_len
241
- self.use_tqdm = use_tqdm
242
- self.use_argmax = use_argmax
243
- self.logits_processor = LogitsProcessorList()
244
- self.logits_warper = LogitsProcessorList(
245
- [
246
- TopKLogitsWarper(top_k), # number of characters
247
- TemperatureLogitsWarper(temperature),
248
- ]
249
- )
250
-
251
- @property
252
- def device(self) -> torch.device:
253
- return self.mario_lm.device
254
-
255
- def step(
256
- self,
257
- seed: torch.Tensor,
258
- encoder_hidden_states: torch.Tensor,
259
- ) -> Tuple[torch.Tensor, torch.Tensor]:
260
- with torch.no_grad():
261
- attention_mask = torch.ones_like(seed).to(seed.device)
262
- input_ids = seed
263
- out = self.mario_lm.lm(
264
- input_ids=input_ids,
265
- attention_mask=attention_mask,
266
- encoder_hidden_states=encoder_hidden_states,
267
- token_type_ids=None,
268
- )
269
- logits = out.logits.detach()
270
- if len(logits.shape) == 2:
271
- logits = logits.view(1, 1, -1)
272
- next_token_logits = logits[:, -1, :]
273
-
274
- if self.use_argmax:
275
- next_tokens = next_token_logits.argmax(-1)
276
- else:
277
- next_token_scores = self.logits_processor(input_ids, next_token_logits)
278
- next_token_scores = self.logits_warper(input_ids, next_token_scores)
279
- probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
280
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
281
- return next_tokens, encoder_hidden_states
282
-
283
- def sample(
284
- self,
285
- seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
286
- prompts: Optional[List[str]] = None,
287
- num_steps: int = 1,
288
- encoder_hidden_states: torch.Tensor = None,
289
- return_tensor: bool = False,
290
- ):
291
- self.mario_lm.eval()
292
- context_len = self.context_len - 28
293
- with torch.no_grad():
294
- if seed is None:
295
- seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
296
- self.device
297
- )
298
- out_tensor = seed.to(self.device)
299
- elif isinstance(seed, SampleOutput):
300
- out_tensor = seed.level_tensor.to(self.device).squeeze()
301
- else:
302
- out_tensor = seed.to(self.device).squeeze()
303
- if len(out_tensor.shape) < 2:
304
- # if we pass in a single seed vector, then we repeat for each prompt
305
- # Otherwise, we treat inputs as separate seed-prompt pairs
306
- out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
307
- if encoder_hidden_states is None:
308
- if prompts is not None:
309
- encoder_hidden_states = torch.stack(
310
- [
311
- self.mario_lm.prompter.output_hidden(prompt)
312
- for prompt in prompts
313
- ]
314
- )
315
- else:
316
- encoder_hidden_states = torch.stack(
317
- [
318
- self.mario_lm.prompter(sample_prompt=True)[1]
319
- for _ in range(seed.shape[0])
320
- ]
321
- )
322
- encoder_hidden_states = encoder_hidden_states.to(
323
- self.device
324
- ) # b x 1 x hidden_dim
325
- encoder_hidden_states = encoder_hidden_states.view(
326
- out_tensor.shape[0], 1, -1
327
- )
328
- if not self.use_tqdm:
329
- bar = np.arange(num_steps)
330
- else:
331
- bar = tqdm(np.arange(num_steps))
332
- with torch.no_grad():
333
- for i in bar:
334
- inp = out_tensor * 1
335
- if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
336
- diff = inp.shape[-1] % 14 # height of mario level
337
- ctx = context_len + diff
338
- inp = inp[:, -ctx:] * 1
339
- next_tokens, encoder_hidden_states = self.step(
340
- inp,
341
- encoder_hidden_states=encoder_hidden_states,
342
- )
343
- out_tensor = torch.cat(
344
- [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
345
- )
346
- if self.use_tqdm:
347
- bar.set_description(
348
- f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
349
- )
350
- if self.use_tqdm:
351
- bar.close()
352
- sample_out = SampleOutput.from_level_predictions(
353
- out_tensor,
354
- out_tensor[:, -num_steps:],
355
- self.mario_lm.tokenizer,
356
- self.mario_lm.prompter,
357
- )
358
- self.mario_lm.train()
359
- if return_tensor:
360
- return sample_out, out_tensor
361
- return sample_out
362
-
363
- def __call__(self, *args, **kwargs):
364
- return self.sample(*args, **kwargs)
365
-
366
-
367
- class BertSampler:
368
- def __init__(
369
- self,
370
- mario_lm: BaseMarioLM,
371
- temperature: float = 2.0,
372
- top_k: int = 16,
373
- context_len: int = 448,
374
- mask_proportion: float = 0.16,
375
- ):
376
- self.mario_lm = mario_lm
377
- self.temperature = temperature
378
- self.top_k = top_k
379
- self.logits_processor = LogitsProcessorList()
380
- self.logits_warper = LogitsProcessorList(
381
- [
382
- TopKLogitsWarper(top_k), # number of characters
383
- TemperatureLogitsWarper(temperature),
384
- ]
385
- )
386
- self.context_len = context_len
387
- self.mask_proportion = mask_proportion
388
- self.mask_portion = int(self.context_len * self.mask_proportion)
389
- self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14
390
-
391
- @property
392
- def device(self) -> torch.device:
393
- return self.mario_lm.device
394
-
395
- def get_context(self, input_ids, mask_indices):
396
- start_idx = mask_indices[0]
397
- end_idx = mask_indices[-1]
398
-
399
- if input_ids.shape[-1] <= self.context_len:
400
- clipped = input_ids.shape[-1] % 14
401
- input_ids = input_ids[:clipped]
402
-
403
- portion = (self.context_len - self.mask_portion) / 2
404
-
405
- remainder = 0
406
- left = start_idx - portion
407
- if left < 0:
408
- remainder = -1 * left
409
-
410
- right = end_idx + portion + remainder
411
-
412
- return input_ids[left:right]
413
-
414
- def sample(
415
- self,
416
- seed: Union[torch.Tensor, SampleOutput],
417
- mask: torch.Tensor,
418
- return_tensor: bool = False,
419
- ):
420
- self.mario_lm.eval()
421
- mask_indices = mask.nonzero()
422
- input_ids = seed
423
- if isinstance(seed, SampleOutput):
424
- input_ids = seed.level_tensor.to(self.device).squeeze()
425
-
426
- input_id_list = []
427
- for i in range(input_ids.shape[0]):
428
- input_id = input_ids[i]
429
- mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
430
- input_id = self.get_context(input_id, mask_index)
431
- input_id_list.append(input_id)
432
- input_ids = torch.stack(input_ids, dim=0).to(self.device)
433
-
434
- attention_mask = torch.ones_like(input_ids).to(seed.device)
435
-
436
- if len(input_ids.shape) < 2:
437
- # if we pass in a single seed vector, then we repeat for each prompt
438
- # Otherwise, we treat inputs as separate seed-prompt pairs
439
- input_ids = input_ids.view(1, -1)
440
-
441
- out = self.mario_lm.lm(
442
- input_ids=input_ids,
443
- attention_mask=attention_mask,
444
- token_type_ids=None,
445
- )
446
- logits = out.logits.detach()
447
- if len(logits.shape) == 2:
448
- logits = logits.view(1, 1, -1)
449
-
450
- if self.use_argmax:
451
- tokens = logits.argmax(-1)
452
- else:
453
- tokens_scores = self.logits_processor(input_ids, tokens)
454
- tokens_scores = self.logits_warper(input_ids, tokens_scores)
455
- probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
456
- tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
457
-
458
- out = input_ids.detach()
459
-
460
- for i in range(input_ids.shape[0]):
461
- mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
462
- out[i, mask_index] = tokens[i, mask_index].detach()
463
-
464
- sample_out = SampleOutput.from_level_predictions(
465
- out,
466
- tokens,
467
- self.mario_lm.tokenizer,
468
- self.mario_lm.prompter,
469
- )
470
- self.mario_lm.train()
471
- if return_tensor:
472
- return sample_out, tokens
473
- return sample_out