schrum2 committed on
Commit
8628a45
·
verified ·
1 Parent(s): 1df3d59

Delete sampler.py

Browse files
Files changed (1) hide show
  1. sampler.py +0 -473
sampler.py DELETED
@@ -1,473 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import List, Optional, Tuple, Union
5
-
6
- import os
7
- import subprocess
8
- import tempfile
9
-
10
- import numpy as np
11
- import torch
12
- from PIL.Image import Image
13
- from tqdm import tqdm
14
- from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
15
-
16
-
17
- from mario_gpt.lm.base import BaseMarioLM
18
- from mario_gpt.prompter import Prompter
19
- from mario_gpt.simulator import Simulator
20
- from mario_gpt.utils import (
21
- convert_level_to_png,
22
- load_level,
23
- save_level,
24
- trim_level,
25
- view_level,
26
- )
27
-
28
def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
    """Render a 2D grid of tile IDs as ASCII rows.

    Args:
        scene: 2D list of ints (one inner list per level row).
        id_to_char: mapping from tile ID to a single ASCII character.
        shorten: when True, keep only the LAST 15 rows so the A* Mario
            agent (for SNES graphics) can run without glitching.

    Returns:
        A list of strings, one per (kept) row.
    """
    rows = scene[-15:] if shorten and len(scene) > 15 else scene
    ascii_rows = []
    for row in rows:
        ascii_rows.append("".join(map(id_to_char.__getitem__, row)))
    return ascii_rows
44
-
45
@dataclass
class SampleOutput:
    """A sampled level plus optional string/image/tensor representations.

    Instances are normally built via ``from_level_predictions`` after
    sampling, or via ``load`` from a previously saved level file.
    """

    # Full level as a list of ASCII rows (None if conversion failed).
    level: Optional[List[str]]
    # Text prompt describing the level, when a prompter is available.
    prompt: Optional[str] = None
    # PNG rendering of the full level.
    img: Optional[Image] = None
    # ASCII rows / PNG rendering for only the newly sampled tokens.
    sample_predictions_str: Optional[List[str]] = None
    sample_predictions_img: Optional[Image] = None
    # Raw token tensors for the full level and the sampled portion.
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None
    # Uses MarioEval graphics for rendering levels when True
    use_snes_graphics: bool = False

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """Build a SampleOutput for a single (batch == 1) level tensor.

        Conversion to string/PNG is best-effort: failures are printed and
        the corresponding fields are left as None.
        """
        # batch = 1
        level = None
        img = None

        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            level = None
            img = None
        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        return SampleOutput(
            level,
            prompt,
            img,
            sample_predictions_str,
            sample_predictions_img,
            level_tensor,
            sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """Convert (possibly batched) level tensors into SampleOutput(s).

        Returns a single SampleOutput when the trimmed level tensor is 1-D
        (unbatched), otherwise one SampleOutput per batch row.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return SampleOutput.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        out = []
        for _level_tensor, _sample_predictions_tensor in zip(
            level_tensor, sample_predictions_tensor
        ):
            sample_output = SampleOutput.create(
                _level_tensor, _sample_predictions_tensor, tokenizer, prompter
            )
            out.append(sample_output)
        return out

    def save(self, filename: str) -> str:
        """Write the level's ASCII rows to *filename* and return the path."""
        save_level(self.level, filename)
        # Fix: the method is annotated `-> str` but previously returned None.
        return filename

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Load a saved level file; only the `level` field is populated."""
        level = load_level(filename)
        return SampleOutput(level=level)

    def play(self, game="mario", level_idx=None, dataset_path=None):
        """
        Play the level using the specified game engine.
        game: "mario" (default) or "loderunner"
        """
        if game == "loderunner":
            import tempfile, json

            # Convert self.level (list of strings) to Lode Runner JSON format
            scene = [[c for c in row] for row in self.level]
            lr_json = [{
                "scene": scene,
                "caption": ""
            }]
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
                json.dump(lr_json, tmp)
                tmp_path = tmp.name
            import sys, os
            # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
            from LodeRunner.loderunner import main
            # A caller-supplied dataset path takes precedence over the temp file.
            tmp_path = tmp_path if dataset_path is None else dataset_path
            print(f"Playing Lode Runner level interactively -- {tmp_path}!")
            main.play_lr_level(tmp_path, level_index=level_idx if level_idx is not None else 1)
        else:
            # Jar choice selects SNES (MarioEval) vs NES graphics.
            if self.use_snes_graphics:
                simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
            else:
                simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
            simulator.interactive()

    def run_astar(self, render=True):
        """Run the A* agent on this level; returns the simulator's console output."""
        if self.use_snes_graphics:
            simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
        else:
            simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
        return simulator.astar(render)
173
-
174
class CustomSimulator:
    """
    The classic Mario simulator used by MarioGPT is generally better, but it
    doesn't return any information about Mario's performance.  The main point
    of this simulator is that information about the performance of the agent
    is printed to the console (though I still need a way to caption and
    return that information).
    """

    def __init__(self, level, jar_path="MarioEval.jar"):
        """
        Args:
            level: list of ASCII rows.  Only the LAST 15 rows are kept,
                because the older A* agent crashes on Mario levels with
                16 rows or more.
            jar_path: path to the Mario evaluation jar to run.
        """
        # Fix: slice instead of `level.pop(0)` in a loop — the original
        # destructively truncated the caller's list as a side effect.
        if len(level) > 15:
            level = level[-15:]
        self.level = level
        self.jar_path = jar_path

    def interactive(self):
        """Launch the jar for interactive (human) play on a temp level file."""
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Playing level interactively -- {t.name}!")
        _ = subprocess.run(
            ["java", "-jar", self.jar_path, "human", t.name, "human"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)

    def astar(self, render: bool = True):
        """Run the jar's A* agent on the level.

        Args:
            render: show the game window ("human") or run headless ("norender").

        Returns:
            str: combined stdout + stderr of the agent process.
        """
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Running Astar agent on level! -- {t.name}")
        render_str = "human" if render else "norender"
        result = subprocess.run(
            ["java", "-jar", self.jar_path, "astar", t.name, render_str],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)
        # Combine stdout and stderr, decode to string, and return
        output = result.stdout.decode("utf-8") + result.stderr.decode("utf-8")
        return output
220
-
221
def save_level(level: List[str], filename: str):
    """Write *level* (one ASCII string per row) to *filename*.

    Rows are joined with newlines; no trailing newline is added.
    Returns the path written, for convenient chaining.

    NOTE(review): this module-level definition shadows the `save_level`
    imported from `mario_gpt.utils` at the top of the file.
    """
    with open(filename, "w") as handle:
        handle.write("\n".join(level))
    return filename
226
-
227
class GPTSampler:
    """Autoregressive sampler for a GPT-based Mario level model.

    Generates levels token-by-token, conditioning each step on prompt
    embeddings passed as encoder hidden states.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        """
        Args:
            mario_lm: model wrapper exposing `lm`, `tokenizer`, `prompter`,
                `generate_seed`, and `device` (see BaseMarioLM).
            temperature: softmax temperature applied when sampling.
            top_k: restrict sampling to the k highest-scoring tokens.
            context_len: maximum token window fed back into the model.
            use_tqdm: show a tqdm progress bar during generation.
            use_argmax: take the most likely token instead of sampling.
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        self.use_argmax = use_argmax
        # Empty processor list kept for symmetry; warpers do top-k then
        # temperature scaling on the next-token scores.
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        # Delegate to the wrapped model's device.
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict one next token per sequence in `seed`.

        Returns:
            (next_tokens, encoder_hidden_states) — the hidden states are
            passed through unchanged so the caller can reuse them.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                # Un-batched model output: restore (batch, seq, vocab) shape.
                logits = logits.view(1, 1, -1)
            # Only the final position's logits matter for the next token.
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: torch.Tensor = None,
        return_tensor: bool = False,
    ):
        """Sample `num_steps` new tokens, continuing from `seed`.

        Args:
            seed: starting tokens — a tensor, a SampleOutput (its
                `level_tensor` is used), or None to generate a fresh seed
                (requires `prompts` for the batch size).
            prompts: text prompts; embedded via the model's prompter when
                `encoder_hidden_states` is not given.
            num_steps: number of tokens to append.
            encoder_hidden_states: precomputed prompt embeddings; overrides
                `prompts`.
            return_tensor: also return the raw output token tensor.

        Returns:
            SampleOutput (or list of them for a batch), plus the raw tensor
            when `return_tensor` is True.
        """
        self.mario_lm.eval()
        # NOTE(review): 28 tokens are reserved out of the context window —
        # presumably two 14-token columns of slack; confirm against the model.
        context_len = self.context_len - 28
        with torch.no_grad():
            if seed is None:
                seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
                    self.device
                )
                out_tensor = seed.to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # if we pass in a single seed vector, then we repeat for each prompt
                # Otherwise, we treat inputs as separate seed-prompt pairs
                out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # No prompts given: draw a random prompt embedding per row.
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
            if not self.use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    # `* 1` makes a copy so clipping below never aliases out_tensor.
                    inp = out_tensor * 1
                    if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                        # Clip to the context window, rounded so the kept
                        # slice starts on a column boundary.
                        diff = inp.shape[-1] % 14  # height of mario level
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                    out_tensor = torch.cat(
                        [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                    )
                    if self.use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                        )
            if self.use_tqdm:
                bar.close()
        sample_out = SampleOutput.from_level_predictions(
            out_tensor,
            out_tensor[:, -num_steps:],
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        # NOTE(review): unconditionally flips the model back to train mode,
        # even if it was in eval mode before sampling.
        self.mario_lm.train()
        if return_tensor:
            return sample_out, out_tensor
        return sample_out

    def __call__(self, *args, **kwargs):
        """Alias for `sample` so the sampler is directly callable."""
        return self.sample(*args, **kwargs)
365
-
366
-
367
class BertSampler:
    """Masked (BERT-style) sampler: resamples the masked positions of a level."""

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 448,
        mask_proportion: float = 0.16,
        use_argmax: bool = False,
    ):
        """
        Args:
            mario_lm: model wrapper exposing `lm`, `tokenizer`, `prompter`, `device`.
            temperature: softmax temperature applied when sampling.
            top_k: restrict sampling to the k highest-scoring tokens.
            context_len: maximum token window fed to the model.
            mask_proportion: fraction of the context reserved for masked tokens.
            use_argmax: take the most likely token instead of sampling.
                Fix: `sample` read `self.use_argmax`, but it was never set,
                so every call crashed with AttributeError.  Added as a
                backward-compatible keyword with a default.
        """
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        self.context_len = context_len
        self.mask_proportion = mask_proportion
        # Size of the masked window, rounded up to a whole 14-token column.
        self.mask_portion = int(self.context_len * self.mask_proportion)
        self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14

    @property
    def device(self) -> torch.device:
        # Delegate to the wrapped model's device.
        return self.mario_lm.device

    def get_context(self, input_ids, mask_indices):
        """Return the slice of `input_ids` centered on the masked span.

        Fix: the original used float division (`/ 2`) and float-valued
        slice bounds, which raise TypeError; all bounds are now ints.
        """
        start_idx = int(mask_indices[0])
        end_idx = int(mask_indices[-1])

        if input_ids.shape[-1] <= self.context_len:
            # Already fits in the window: just drop the trailing partial
            # column (levels are 14 tokens per column) and return.
            # NOTE(review): the original clipped to `[:clipped]` (keeping
            # only the remainder) and then fell through — presumed a bug.
            clipped = input_ids.shape[-1] % 14
            if clipped:
                input_ids = input_ids[:-clipped]
            return input_ids

        # Split the non-masked budget evenly on both sides of the mask.
        portion = (self.context_len - self.mask_portion) // 2

        left = start_idx - portion
        remainder = 0
        if left < 0:
            # Window would start before the level: shift the slack right.
            remainder = -left
            left = 0

        right = end_idx + portion + remainder
        return input_ids[left:right]

    def sample(
        self,
        seed: Union[torch.Tensor, SampleOutput],
        mask: torch.Tensor,
        return_tensor: bool = False,
    ):
        """Resample the positions of `seed` flagged by the boolean `mask`.

        Args:
            seed: batch of token rows (b x seq) or a SampleOutput whose
                `level_tensor` is used.
            mask: tensor whose nonzero entries mark positions to resample.
            return_tensor: also return the raw sampled-token tensor.

        Returns:
            SampleOutput (or list of them), plus the token tensor when
            `return_tensor` is True.
        """
        self.mario_lm.eval()
        mask_indices = mask.nonzero()
        input_ids = seed
        if isinstance(seed, SampleOutput):
            input_ids = seed.level_tensor.to(self.device).squeeze()

        input_id_list = []
        for i in range(input_ids.shape[0]):
            input_id = input_ids[i]
            mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
            input_id = self.get_context(input_id, mask_index)
            input_id_list.append(input_id)
        # Fix: stack the per-row context slices — the original stacked
        # `input_ids` itself, discarding the clipping and raising for 2-D input.
        input_ids = torch.stack(input_id_list, dim=0).to(self.device)

        # Fix: use input_ids.device — `seed` may be a SampleOutput with no .device.
        attention_mask = torch.ones_like(input_ids).to(input_ids.device)

        if len(input_ids.shape) < 2:
            # A single seed vector becomes a batch of one.
            input_ids = input_ids.view(1, -1)

        out = self.mario_lm.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=None,
        )
        logits = out.logits.detach()
        if len(logits.shape) == 2:
            logits = logits.view(1, 1, -1)

        if self.use_argmax:
            tokens = logits.argmax(-1)
        else:
            # Fix: warp the *logits* — the original passed `tokens`, which
            # was referenced before assignment.  Sample one token per position.
            tokens_scores = self.logits_processor(input_ids, logits)
            tokens_scores = self.logits_warper(input_ids, tokens_scores)
            probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
            flat = probs.view(-1, probs.shape[-1])
            tokens = torch.multinomial(flat, num_samples=1).view(
                probs.shape[0], probs.shape[1]
            )

        out = input_ids.detach()

        # Overwrite only the masked positions with the sampled tokens.
        for i in range(input_ids.shape[0]):
            mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
            out[i, mask_index] = tokens[i, mask_index].detach()

        sample_out = SampleOutput.from_level_predictions(
            out,
            tokens,
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        self.mario_lm.train()
        if return_tensor:
            return sample_out, tokens
        return sample_out