| | from __future__ import annotations
|
| |
|
| | from dataclasses import dataclass
|
| | from typing import List, Optional, Tuple, Union
|
| |
|
| | import os
|
| | import subprocess
|
| | import tempfile
|
| |
|
| | import numpy as np
|
| | import torch
|
| | from PIL.Image import Image
|
| | from tqdm import tqdm
|
| | from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
|
| |
|
| |
|
| | from mario_gpt.lm.base import BaseMarioLM
|
| | from mario_gpt.prompter import Prompter
|
| | from mario_gpt.simulator import Simulator
|
| | from mario_gpt.utils import (
|
| | convert_level_to_png,
|
| | load_level,
|
| | save_level,
|
| | trim_level,
|
| | view_level,
|
| | )
|
| |
|
def scene_to_ascii(scene, id_to_char, shorten: bool = True) -> List[str]:
    """
    Convert a JSON scene (a list of lists of tile IDs) into a list of
    ASCII strings using the id_to_char mapping.

    Args:
        scene: List[List[int]] - 2D array of tile IDs (one inner list per row)
        id_to_char: Dict[int, str] - mapping from tile ID to ASCII character
        shorten: bool - If True, keep only the *last* 15 rows so that
            A* Mario (for SNES graphics) can run without glitching.
            (The old docstring said "first 15 rows" in one place; the code
            has always kept the last 15.)

    Returns:
        List[str]: one ASCII string per remaining row
    """
    if shorten and len(scene) > 15:
        # Keep the bottom 15 rows -- the playable floor is at the bottom.
        scene = scene[-15:]
    return ["".join(id_to_char[num] for num in row) for row in scene]
|
| |
|
@dataclass
class SampleOutput:
    """
    Result bundle for one sampled level: the decoded ASCII level, its
    prompt, PNG renderings, and the raw token tensors.
    """

    # Decoded level rows (None when decoding failed).
    level: Optional[List[str]]
    prompt: Optional[str] = None
    img: Optional[Image] = None
    # String / image views of only the newly sampled tokens.
    sample_predictions_str: Optional[List[str]] = None
    sample_predictions_img: Optional[Image] = None
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None

    # When True, play/run_astar use the SNES-graphics jar (MarioEval.jar)
    # instead of the NES one.
    use_snes_graphics: bool = False

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """
        Build a SampleOutput from one (1-D) level tensor, best-effort
        decoding string and PNG views; decoding failures are printed and
        the corresponding fields left as None.
        """
        level = None
        img = None

        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            level = None
            img = None
        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        # Use cls (not the hard-coded class name) so subclasses construct
        # instances of themselves.
        return cls(
            level,
            prompt,
            img,
            sample_predictions_str,
            sample_predictions_img,
            level_tensor,
            sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """
        Convert raw model tensors into SampleOutput(s): a single object for
        a 1-D (un-batched) level tensor, a list for a batched 2-D one.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return cls.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        out = []
        for _level_tensor, _sample_predictions_tensor in zip(
            level_tensor, sample_predictions_tensor
        ):
            sample_output = cls.create(
                _level_tensor, _sample_predictions_tensor, tokenizer, prompter
            )
            out.append(sample_output)
        return out

    def save(self, filename: str) -> str:
        """
        Write the level rows to `filename` and return the path.

        Fix: the old body dropped save_level's return value despite the
        declared -> str annotation.
        """
        return save_level(self.level, filename)

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Load a previously saved level file into a bare SampleOutput."""
        level = load_level(filename)
        return cls(level=level)

    def play(self, game="mario", level_idx=None, dataset_path=None):
        """
        Play the level using the specified game engine.
        game: "mario" (default) or "loderunner"
        """
        if game == "loderunner":
            import tempfile, json

            # Wrap the level in the single-scene JSON layout the Lode
            # Runner player expects.
            scene = [[c for c in row] for row in self.level]
            lr_json = [{
                "scene": scene,
                "caption": ""
            }]
            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
                json.dump(lr_json, tmp)
                tmp_path = tmp.name

            from LodeRunner.loderunner import main
            # An explicit dataset_path overrides the temp file written above.
            tmp_path = tmp_path if dataset_path is None else dataset_path
            print(f"Playing Lode Runner level interactively -- {tmp_path}!")
            main.play_lr_level(tmp_path, level_index=level_idx if level_idx is not None else 1)
        else:
            if self.use_snes_graphics:
                simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
            else:
                simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
            simulator.interactive()

    def run_astar(self, render=True):
        """Run the A* agent on this level; returns the simulator's console output."""
        if self.use_snes_graphics:
            simulator = CustomSimulator(level=self.level, jar_path="MarioEval.jar")
        else:
            simulator = CustomSimulator(level=self.level, jar_path="NESMarioEval.jar")
        return simulator.astar(render)
|
| |
|
class CustomSimulator:
    """
    The classic Mario simulator used by MarioGPT is generally
    better, but it doesn't return any information about
    Mario's performance. The main point of this simulator
    is that information about the performance of the agent
    is printed to the console (though I still need a way
    to caption and return that information)
    """

    def __init__(self, level, jar_path="MarioEval.jar"):
        """
        Args:
            level: list of ASCII rows; only the bottom 15 rows are kept,
                matching what the Java evaluator can render.
            jar_path: path to the Mario evaluation jar.
        """
        # Fix: the old code popped rows off the *caller's* list in place
        # (while len > 15: level.pop(0)), silently truncating e.g. a
        # SampleOutput.level passed in. Slice instead -- same rows kept,
        # no mutation of the input.
        if len(level) > 15:
            level = level[-15:]

        self.level = level
        self.jar_path = jar_path

    def interactive(self):
        """Launch the jar for interactive (human) play of self.level."""
        # delete=False so the Java process can reopen the file by name;
        # the file is removed manually below.
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Playing level interactively -- {t.name}!")
        _ = subprocess.run(
            ["java", "-jar", self.jar_path, "human", t.name, "human"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)

    def astar(self, render: bool = True):
        """
        Run the jar's A* agent on self.level and return its combined
        stdout + stderr text (which carries the performance report).
        """
        t = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        save_level(self.level, t.name)
        print(f"Running Astar agent on level! -- {t.name}")
        render_str = "human" if render else "norender"
        result = subprocess.run(
            ["java", "-jar", self.jar_path, "astar", t.name, render_str],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        t.close()
        os.unlink(t.name)

        output = result.stdout.decode("utf-8") + result.stderr.decode("utf-8")
        return output
|
| |
|
def save_level(level: List[str], filename: str):
    """Write the rows of *level* to *filename*, newline-separated.

    Returns the filename so calls can be chained.
    """
    with open(filename, "w") as out_file:
        out_file.write("\n".join(level))
    return filename
|
| |
|
class GPTSampler:
    """
    Autoregressive sampler for a GPT-style MarioLM.

    Repeatedly feeds the running token sequence to the LM and appends one
    sampled token per step. Sampling uses top-k filtering followed by
    temperature scaling, or plain argmax when use_argmax is True.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        # Warpers run in order: restrict to the top_k logits, then rescale
        # by temperature.
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        """Device of the wrapped language model."""
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one forward pass and sample the next token for every sequence
        in the batch.

        Args:
            seed: (batch, seq_len) token ids fed to the LM.
            encoder_hidden_states: prompt embedding passed through to the LM.

        Returns:
            (next_tokens, encoder_hidden_states): sampled token ids of
            shape (batch,) and the unchanged hidden states.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                # Un-batched output: restore batch and sequence dims.
                logits = logits.view(1, 1, -1)
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        return_tensor: bool = False,
    ):
        """
        Generate num_steps tokens, starting from `seed` (or a fresh seed
        built from `prompts` when seed is None).

        Args:
            seed: starting tokens -- a tensor, a SampleOutput, or None.
            prompts: text prompts; required when seed is None.
            num_steps: number of tokens to generate.
            encoder_hidden_states: precomputed prompt embeddings; derived
                from `prompts` (or sampled randomly) when None.
            return_tensor: also return the raw output token tensor.

        Returns:
            A SampleOutput (or list of them for batched output), plus the
            token tensor when return_tensor is True.
        """
        self.mario_lm.eval()
        # Leave headroom (two columns of 14 tokens) under the hard limit.
        context_len = self.context_len - 28
        with torch.no_grad():
            if seed is None:
                seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
                    self.device
                )
                out_tensor = seed.to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # Single sequence: promote to a batch. Fix: the old code
                # always used len(prompts) here and crashed when a seed was
                # given with prompts=None.
                batch_size = len(prompts) if prompts is not None else 1
                out_tensor = out_tensor.view(1, -1).repeat(batch_size, 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # No prompts: sample one random prompt embedding per
                    # batch row. Fix: was seed.shape[0], which is wrong
                    # for SampleOutput seeds and for squeezed 1-D tensors.
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(out_tensor.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
            if not self.use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    inp = out_tensor * 1
                    if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                        # Clip to roughly the last context_len tokens,
                        # padded so the window length stays aligned to
                        # whole 14-token columns (presumably: 14 tiles per
                        # level column -- TODO confirm).
                        diff = inp.shape[-1] % 14
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                    out_tensor = torch.cat(
                        [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                    )
                    if self.use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                        )
            if self.use_tqdm:
                bar.close()
            sample_out = SampleOutput.from_level_predictions(
                out_tensor,
                out_tensor[:, -num_steps:],
                self.mario_lm.tokenizer,
                self.mario_lm.prompter,
            )
            # NOTE(review): unconditionally flips the model back to train
            # mode even if it was already in eval mode before sampling.
            self.mario_lm.train()
            if return_tensor:
                return sample_out, out_tensor
            return sample_out

    def __call__(self, *args, **kwargs):
        """Alias for sample()."""
        return self.sample(*args, **kwargs)
|
| |
|
| |
|
class BertSampler:
    """
    Masked-token sampler for a BERT-style MarioLM: re-samples the token
    positions selected by a mask tensor instead of generating
    left-to-right.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 448,
        mask_proportion: float = 0.16,
        use_argmax: bool = False,
    ):
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        # Fix: sample() reads self.use_argmax but the old __init__ never
        # set it (AttributeError). Exposed as a backward-compatible
        # keyword argument defaulting to sampling.
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),
                TemperatureLogitsWarper(temperature),
            ]
        )
        self.context_len = context_len
        self.mask_proportion = mask_proportion
        # Number of maskable tokens per window, rounded up to a whole
        # 14-token column.
        self.mask_portion = int(self.context_len * self.mask_proportion)
        self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14

    @property
    def device(self) -> torch.device:
        """Device of the wrapped language model."""
        return self.mario_lm.device

    def get_context(self, input_ids, mask_indices):
        """
        Return a window of input_ids (about context_len long) centered on
        the masked span given by mask_indices.
        """
        start_idx = int(mask_indices[0])
        end_idx = int(mask_indices[-1])

        # Whole sequence already fits: nothing to window. (The old code
        # sliced to `len % 14` here, keeping only the ragged remainder --
        # almost certainly a bug.)
        if input_ids.shape[-1] <= self.context_len:
            return input_ids

        # Integer half of the non-masked budget for each side of the span.
        portion = (self.context_len - self.mask_portion) // 2

        remainder = 0
        left = start_idx - portion
        if left < 0:
            # Clamp at the sequence start and hand the unused budget to the
            # right side. (The old code kept the negative index, which
            # Python interprets as slicing from the end.)
            remainder = -1 * left
            left = 0

        right = end_idx + portion + remainder

        return input_ids[left:right]

    def sample(
        self,
        seed: Union[torch.Tensor, SampleOutput],
        mask: torch.Tensor,
        return_tensor: bool = False,
    ):
        """
        Re-sample the tokens of `seed` at the positions where `mask` is
        nonzero.

        Args:
            seed: (batch, seq_len) token tensor or a SampleOutput.
            mask: tensor whose nonzero entries mark positions to re-sample.
            return_tensor: also return the sampled-token tensor.

        Returns:
            A SampleOutput (or list of them), plus the token tensor when
            return_tensor is True.
        """
        self.mario_lm.eval()
        mask_indices = mask.nonzero()
        input_ids = seed
        if isinstance(seed, SampleOutput):
            input_ids = seed.level_tensor.to(self.device).squeeze()

        input_id_list = []
        for i in range(input_ids.shape[0]):
            input_id = input_ids[i]
            mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
            input_id = self.get_context(input_id, mask_index)
            input_id_list.append(input_id)
        # Fix: the old code stacked `input_ids`, silently discarding the
        # clipped contexts built above. NOTE(review): this assumes every
        # row's context window has the same length -- confirm for masks
        # whose spans differ across the batch.
        input_ids = torch.stack(input_id_list, dim=0).to(self.device)

        # Fix: was seed.device, which fails when seed is a SampleOutput.
        attention_mask = torch.ones_like(input_ids).to(input_ids.device)

        if len(input_ids.shape) < 2:
            input_ids = input_ids.view(1, -1)

        out = self.mario_lm.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=None,
        )
        logits = out.logits.detach()
        if len(logits.shape) == 2:
            logits = logits.view(1, 1, -1)

        if self.use_argmax:
            tokens = logits.argmax(-1)
        else:
            # Fix: the old code passed the not-yet-defined `tokens` into
            # the processor; the per-position logits are what get warped.
            tokens_scores = self.logits_processor(input_ids, logits)
            tokens_scores = self.logits_warper(input_ids, tokens_scores)
            probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
            # torch.multinomial only accepts 2-D input: flatten
            # (batch, seq, vocab) to rows, draw one token per row, then
            # restore the (batch, seq) shape.
            flat_probs = probs.view(-1, probs.shape[-1])
            tokens = torch.multinomial(flat_probs, num_samples=1).view(
                probs.shape[0], probs.shape[1]
            )

        out = input_ids.detach()

        # Splice the freshly sampled tokens back into the masked positions.
        # NOTE(review): mask_index holds positions in the *original*
        # sequence; when get_context actually clips, these may not line up
        # with the window -- confirm against callers.
        for i in range(input_ids.shape[0]):
            mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
            out[i, mask_index] = tokens[i, mask_index].detach()

        sample_out = SampleOutput.from_level_predictions(
            out,
            tokens,
            self.mario_lm.tokenizer,
            self.mario_lm.prompter,
        )
        self.mario_lm.train()
        if return_tensor:
            return sample_out, tokens
        return sample_out
|
| |
|