Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import random | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from scipy import stats | |
| from transformers import pipeline | |
| from mario_gpt.dataset import MarioDataset | |
| from mario_gpt.utils import view_level | |
# Default count thresholds that ``Prompter`` falls back to when no
# ``statistics`` mapping is supplied. Each array holds increasing bin edges
# that are fed to ``np.digitize(..., right=True)`` to bucket a feature count
# (enemies, pipes, blocks) into a descriptive keyword.
STATISTICS = dict(
    enemy=np.array([1.0, 3.0, 7.0]),
    pipe=np.array([0.0, 2.0, 5.0]),
    block=np.array([50.0, 75.0, 176.0]),
)

# Hugging Face model used (as both model and tokenizer) by the
# feature-extraction pipeline that embeds prompt text.
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
class Prompter:
    """Describe Mario levels with short text prompts and embed those prompts.

    A prompt is four comma-separated phrases covering pipes, enemies, blocks
    and elevation, e.g. ``"no pipes, some enemies, many blocks, low elevation"``.
    Pipe/enemy/block counts are bucketed into keywords with ``np.digitize``
    against the per-feature thresholds in ``self.statistics``, or reported
    verbatim when ``use_raw_counts`` is True. The prompt text is embedded by a
    Hugging Face ``feature-extraction`` pipeline.
    """

    # Keyword for each np.digitize bin. N thresholds produce bin indices
    # 0..N, so each tuple needs len(thresholds) + 1 entries (4 for the
    # 3-element default thresholds). Shared by the *_thresholds methods and
    # the random-sampling branch of __call__ so the two cannot drift apart.
    _PIPE_KEYWORDS = ("no", "little", "some", "many")
    _ENEMY_KEYWORDS = ("no", "little", "some", "many")
    # Levels always have blocks, so the lowest bin is labelled "little" too.
    _BLOCK_KEYWORDS = ("little", "little", "some", "many")
    _ELEVATION_KEYWORDS = ("low", "high")

    def __init__(
        self,
        level_tokenizer,
        prompter_model: str = FEATURE_EXTRACTION_MODEL,
        use_raw_counts: bool = False,
        statistics: Optional[Dict[str, Any]] = None,
    ):
        """Build the feature-extraction pipeline and store configuration.

        Args:
            level_tokenizer: Tokenizer handed to ``view_level`` to decode a
                level tensor back into rows of tile characters.
            prompter_model: Model name used as both the model and the
                tokenizer of the ``feature-extraction`` pipeline.
            use_raw_counts: If True, prompts contain raw counts (``"3 pipes"``)
                instead of bucketed keywords (``"some pipes"``).
            statistics: Mapping with ``"enemy"``, ``"pipe"`` and ``"block"``
                threshold arrays; defaults to the module-level ``STATISTICS``.
        """
        self.prompter_model = prompter_model
        self.feature_extraction = pipeline(
            "feature-extraction",
            model=prompter_model,
            tokenizer=prompter_model,
            framework="pt",
        )
        self.level_tokenizer = level_tokenizer
        self.use_raw_counts = use_raw_counts
        self.statistics = statistics if statistics is not None else STATISTICS

    def pipe_thresholds(self) -> Tuple[Any, List[str]]:
        """Return ``(thresholds, keywords)`` used to bucket pipe counts."""
        return self.statistics["pipe"], list(self._PIPE_KEYWORDS)

    def enemy_thresholds(self) -> Tuple[Any, List[str]]:
        """Return ``(thresholds, keywords)`` used to bucket enemy counts."""
        return self.statistics["enemy"], list(self._ENEMY_KEYWORDS)

    def block_thresholds(self) -> Tuple[Any, List[str]]:
        """Return ``(thresholds, keywords)`` used to bucket block counts."""
        return self.statistics["block"], list(self._BLOCK_KEYWORDS)

    def count_pipes(self, flattened_level: str) -> int:
        """Count pipes as non-overlapping occurrences of ``"<>"``."""
        return flattened_level.count("<>")

    def count_enemies(self, flattened_level: str) -> int:
        """Count enemy tiles: every ``"E"`` and ``"B"`` character."""
        return flattened_level.count("E") + flattened_level.count("B")

    def count_blocks(self, flattened_level: str) -> int:
        """Count block tiles: ``X``, ``S``, ``?`` and ``Q`` characters.

        Uses the builtin ``sum`` so a plain ``int`` is returned, matching the
        annotation (``np.sum`` would return a NumPy integer).
        """
        return sum(flattened_level.count(tile) for tile in ("X", "S", "?", "Q"))

    def _flatten_level(self, string_level: List[str]) -> str:
        """Concatenate the level's rows into a single string."""
        return "".join(string_level)

    def _count_keyword(self, count: int, thresholds_fn) -> str:
        """Turn a feature count into its prompt keyword.

        In raw-count mode this is just the count as a string and
        ``thresholds_fn`` is never called. Otherwise ``thresholds_fn()``
        supplies ``(thresholds, keywords)`` and the count's bin selects the
        keyword.
        """
        if self.use_raw_counts:
            return f"{count}"
        thresholds, keywords = thresholds_fn()
        # right=True: bin i covers thresholds[i-1] < count <= thresholds[i],
        # so a count equal to thresholds[i] lands in bin i.
        return keywords[np.digitize(count, thresholds, right=True)]

    def pipe_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
        """Return ``("<keyword> pipes", keyword)`` for the level.

        ``level`` is unused; it keeps all ``*_prompt`` methods on one signature.
        """
        keyword = self._count_keyword(
            self.count_pipes(flattened_level), self.pipe_thresholds
        )
        return f"{keyword} pipes", keyword

    def enemy_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
        """Return ``("<keyword> enemies", keyword)`` for the level.

        ``level`` is unused; it keeps all ``*_prompt`` methods on one signature.
        """
        keyword = self._count_keyword(
            self.count_enemies(flattened_level), self.enemy_thresholds
        )
        return f"{keyword} enemies", keyword

    def block_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
        """Return ``("<keyword> blocks", keyword)`` for the level.

        ``level`` is unused; it keeps all ``*_prompt`` methods on one signature.
        """
        keyword = self._count_keyword(
            self.count_blocks(flattened_level), self.block_thresholds
        )
        return f"{keyword} blocks", keyword

    def elevation_prompt(self, flattened_level: str, level: List[str]) -> Tuple[str, str]:
        """Return ``("high elevation", "high")`` or ``("low elevation", "low")``.

        The level is "high" when any of its first six rows contains an ``X``,
        ``<`` or ``>`` tile. ``flattened_level`` is unused; it keeps all
        ``*_prompt`` methods on one signature.
        """
        # Original note: these rows are "elevation 8 and up" — presumably
        # rows are ordered top to bottom; TODO confirm against view_level.
        top_rows = level[:6]
        if any(tile in row for row in top_rows for tile in "X<>"):
            return "high elevation", "high"
        return "low elevation", "low"

    def output_hidden(
        self, prompt: str, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Embed ``prompt`` and return it as a ``(1, hidden)`` tensor on ``device``.

        The first element of the pipeline output is averaged over dimension 0
        (presumably the token axis — TODO confirm) and reshaped to one row.
        """
        # Reducing along the first dimension to get a 768 dimensional array
        # (768 is presumably the hidden size of the default bart-base model).
        features = self.feature_extraction(prompt, return_tensors="pt")[0]
        return features.mean(0).to(device).view(1, -1)

    def dataset_statistics(self, dataset: MarioDataset) -> Dict[str, Any]:
        """Compute per-feature count thresholds from a dataset.

        Each ``dataset[i][0]`` is decoded with ``view_level`` using
        ``dataset.tokenizer``, flattened, and its enemies, pipes and blocks are
        counted. Returns a dict with the same keys as ``STATISTICS``
        (``"enemy"``, ``"pipe"``, ``"block"``). Each key maps to the 0.33, 0.66
        and 0.95 quantiles of that feature's counts, computed with
        ``scipy.stats.mstats.mquantiles``.
        """
        quantiles = [0.33, 0.66, 0.95]
        counts: Dict[str, List[int]] = {"enemy": [], "pipe": [], "block": []}
        # Index explicitly: the dataset is only known to support len() and [].
        for i in range(len(dataset)):
            level, _ = dataset[i]
            flat = self._flatten_level(view_level(level, dataset.tokenizer))
            counts["enemy"].append(self.count_enemies(flat))
            counts["pipe"].append(self.count_pipes(flat))
            counts["block"].append(self.count_blocks(flat))
        return {
            name: stats.mstats.mquantiles(values, quantiles)
            for name, values in counts.items()
        }

    def __call__(
        self, level: Optional[torch.Tensor] = None, sample_prompt: bool = False
    ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
        """Build a prompt for ``level`` (or a random one) and embed it.

        Args:
            level: Level tensor to describe; required unless ``sample_prompt``.
            sample_prompt: If True, ignore ``level`` and draw every phrase's
                keyword with ``random.choice``.

        Returns:
            A ``(prompt, hidden, prompt_dict, str_level)`` tuple:

            - ``prompt``: the comma-joined phrase string.
            - ``hidden``: ``output_hidden(prompt)``, placed on ``level.device``
              (CPU when sampling).
            - ``prompt_dict``: maps ``"pipe"``, ``"enemy"``, ``"block"`` and
              ``"elevation_prompt"`` to the individual phrases.
            - ``str_level``: the decoded level rows, or ``None`` when sampling.

        Raises:
            ValueError: If ``level`` is None and ``sample_prompt`` is False.
        """
        device: torch.device = torch.device("cpu")
        if not sample_prompt:
            if level is None:
                raise ValueError("Level must be provided if sample_prompt is not true!")
            str_level = view_level(level, self.level_tokenizer)
            flattened_level = self._flatten_level(str_level)
            pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
            enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
            block_prompt, _ = self.block_prompt(flattened_level, str_level)
            elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
            # Embed on the same device as the input level.
            device = level.device
        else:
            str_level = None
            pipe_prompt = random.choice(self._PIPE_KEYWORDS) + " pipes"
            enemy_prompt = random.choice(self._ENEMY_KEYWORDS) + " enemies"
            # "little" appears twice in the block keywords, so it is drawn
            # half the time (levels always have blocks).
            block_prompt = random.choice(self._BLOCK_KEYWORDS) + " blocks"
            elevation_prompt = random.choice(self._ELEVATION_KEYWORDS) + " elevation"
        # The "elevation_prompt" key differs in style from the other keys but
        # is kept unchanged for existing callers.
        prompt_dict = {
            "pipe": pipe_prompt,
            "enemy": enemy_prompt,
            "block": block_prompt,
            "elevation_prompt": elevation_prompt,
        }
        prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
        hidden = self.output_hidden(prompt, device=device)
        return prompt, hidden, prompt_dict, str_level