File size: 8,588 Bytes
from interactive_generation import InteractiveGeneration
import torch
from level_dataset import visualize_samples, convert_to_level_format, positive_negative_caption_split
from caption_match import compare_captions, process_scene_segments
from create_ascii_captions import assign_caption
from util import extract_tileset
from sampler import scene_to_ascii
import argparse
import common_settings as common_settings
from sampler import SampleOutput
from pipeline_loader import get_pipeline
def parse_args(argv=None):
    """Parse the command-line options for interactive level generation.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the process command line, which is the
            behavior callers had before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``model_path``, ``tileset``,
        ``describe_absence``, ``automatic_negative_captions`` and ``game``.
    """
    parser = argparse.ArgumentParser(description="Generate levels using a trained diffusion model")

    # Model and generation parameters
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    # Raw string: the Windows-style path contains "\T", "\S" and "\s", which
    # are invalid escape sequences in a normal string literal. The runtime
    # value is unchanged.
    parser.add_argument("--tileset", default=r'..\TheVGLC\Super Mario Bros\smb.json', help="Descriptions of individual tile types")
    #parser.add_argument("--describe_locations", action="store_true", default=False, help="Include location descriptions in the captions")
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    parser.add_argument("--automatic_negative_captions", action="store_true", default=False, help="Automatically create negative captions for prompts so the user doesn't have to")
    parser.add_argument(
        "--game",
        type=str,
        default="Mario",
        choices=["Mario", "LR"],
        help="Which game to create a model for (affects sample style and tile count)"
    )

    return parser.parse_args(argv)
class InteractiveLevelGeneration(InteractiveGeneration):
    """Interactive generation front end that produces game levels from captions.

    Loads a pipeline from ``args.model_path``, declares the parameters the
    user is prompted for, and after each generation captions the produced
    scene, prints how well it matches the prompt, and optionally runs the
    level through ``SampleOutput.run_astar``.

    NOTE(review): the Lode Runner path calls ``lr_assign_caption``,
    ``lr_compare_captions`` and ``lr_process_scene_segments``, none of which
    are imported in the visible part of this file. Confirm where they are
    meant to come from; as written, ``--game LR`` raises NameError.
    """

    def __init__(self, args):
        """Set up parameters, load the pipeline and read the tileset.

        Args:
            args: Parsed command-line namespace. Reads ``game``,
                ``model_path``, ``automatic_negative_captions`` and
                ``tileset``.

        Raises:
            ValueError: If automatic negative captions are requested but the
                loaded pipeline does not support negative prompts.
        """
        # Derive the default width from the selected game instead of relying
        # on a module-level ``width`` that only exists when this file is run
        # as a script.
        if args.game == "LR":
            default_width = common_settings.LR_WIDTH
        else:
            default_width = common_settings.MARIO_WIDTH

        super().__init__(
            {
                "caption": str,
                "width": int,
                "negative_prompt": str,
                "start_seed": int,
                "end_seed": int,
                "num_inference_steps": int,
                "guidance_scale": float
            },
            default_parameters={
                "width": default_width,
                "start_seed": 1,
                "end_seed": 1,  # Will be set to start_seed if blank
                "num_inference_steps": common_settings.NUM_INFERENCE_STEPS,
                "guidance_scale": common_settings.GUIDANCE_SCALE,
                "caption": "",
                "negative_prompt": ""
            }
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = get_pipeline(args.model_path).to(self.device)
        self.pipe.print_unet_architecture()
        #self.pipe.save_unet_architecture_pdf(height, width)

        # Do not prompt for a negative prompt when it is generated
        # automatically or when the pipeline cannot use one.
        if args.automatic_negative_captions or not self.pipe.supports_negative_prompt:
            self.input_parameters.pop('negative_prompt', None)
            self.default_parameters.pop('negative_prompt', None)

        if args.automatic_negative_captions and not self.pipe.supports_negative_prompt:
            raise ValueError("Automatic negative caption generation is not possible with a model that doesn't support it")

        if args.tileset:
            _, self.id_to_char, self.char_to_id, self.tile_descriptors = extract_tileset(args.tileset)

        self.args = args

        # Lode Runner width is forced in get_extra_params, so the user is not
        # asked for it.
        if self.args.game == "LR":
            del self.input_parameters["width"]

        print(f"Tileset in use: {self.args.tileset}")

    def generate_image(self, param_values, generator, **extra_params):
        """Generate one level, report caption agreement, and return its image.

        Args:
            param_values: Keyword arguments forwarded to the pipeline call.
                May gain a ``negative_prompt`` entry when automatic negative
                captions are enabled.
            generator: Passed through to the pipeline as ``generator``.
            **extra_params: Accepted for interface compatibility; not used.

        Returns:
            The result of ``visualize_samples`` applied to the pipeline's
            output images.

        Raises:
            ValueError: If ``self.args.game`` is not a known game, or if the
                answer to the play prompt is neither ``y`` nor ``n``.
        """
        game = self.args.game

        # get_extra_params removes an empty caption, so it may be absent
        # here; there is nothing to derive a negative prompt from then.
        if self.args.automatic_negative_captions and "caption" in param_values:
            _, negative = positive_negative_caption_split(param_values["caption"], True)
            param_values["negative_prompt"] = negative

        images = self.pipe(
            generator=generator,
            **param_values
        ).images

        # Convert the first sample to a grid of tile indices
        sample_tensor = images[0].unsqueeze(0)
        sample_indices = convert_to_level_format(sample_tensor)
        scene = sample_indices[0].tolist()

        if game == "LR":
            # Wrap every index into the range of Lode Runner tile ids
            number_of_tiles = common_settings.LR_TILE_COUNT
            scene = [[x % number_of_tiles for x in row] for row in scene]

        # Select the captioning and comparison helpers for the game in use.
        # The names are resolved inside the branches so that the Mario path
        # never touches the lr_* names.
        if game == "Mario":
            caption_fn = assign_caption
            compare_fn = compare_captions
            segments_fn = process_scene_segments
            level_width = common_settings.MARIO_WIDTH
        elif game == "LR":
            # NOTE(review): these lr_* helpers are not imported in the
            # visible file - confirm their source module.
            caption_fn = lr_assign_caption
            compare_fn = lr_compare_captions
            segments_fn = lr_process_scene_segments
            level_width = common_settings.LR_WIDTH
        else:
            raise ValueError(f"Unknown game: {game}")

        # Assign a caption to the scene of whichever game is being played
        actual_caption = caption_fn(
            scene,
            self.id_to_char,
            self.char_to_id,
            self.tile_descriptors,
            False,
            self.args.describe_absence,
        )

        prompt = param_values.get("caption", "")

        # Only the Lode Runner path echoes the generated caption
        if game == "LR":
            print(f"Describe resulting image: {actual_caption}")

        compare_score = compare_fn(prompt, actual_caption)
        print(f"Comparison score: {compare_score}")

        # Called with verbose=True, presumably for its printed per-segment
        # report; the returned scores and captions are not used here.
        segments_fn(
            scene=scene,
            segment_width=level_width,
            prompt=prompt,
            id_to_char=self.id_to_char,
            char_to_id=self.char_to_id,
            tile_descriptors=self.tile_descriptors,
            describe_locations=False,  #self.args.describe_locations,
            describe_absence=self.args.describe_absence,
            verbose=True
        )

        # Ask if user wants to play level
        play_level = input("Do you want to play this level? (y/n): ").strip().lower()
        if play_level == 'y':
            print("Playing level...")
            char_grid = scene_to_ascii(scene, self.id_to_char, False)
            level = SampleOutput(level=char_grid, use_snes_graphics=False)
            console_output = level.run_astar()
            print(console_output)
        elif play_level == 'n':
            print("Level not played.")
        else:
            raise ValueError(f"Unknown input: {play_level}")

        return visualize_samples(images)

    def get_extra_params(self, param_values):
        """Normalize ``param_values`` in place before generation.

        Drops empty ``negative_prompt`` and ``caption`` entries, requests
        tensor output, and forces the Lode Runner height and width when that
        game is selected.

        Args:
            param_values: Parameter dictionary; modified in place.

        Returns:
            An empty dict.
        """
        # .get() tolerates either key being absent
        if param_values.get("negative_prompt") == "":
            del param_values["negative_prompt"]

        if param_values.get("caption") == "":
            del param_values["caption"]

        param_values["output_type"] = "tensor"

        # Lode Runner
        if self.args.game == "LR":
            param_values["height"] = common_settings.LR_HEIGHT
            param_values["width"] = common_settings.LR_WIDTH

        return {}
if __name__ == "__main__":
    # Script entry point: parse options, attach the per-game settings to
    # ``args``, then start the interactive session.
    args = parse_args()

    # Raw strings are used for the tileset paths because "\T", "\S", "\s"
    # and "\L" are invalid escape sequences in normal string literals; the
    # resulting values are unchanged.
    # NOTE(review): args.tileset is overwritten unconditionally, so a
    # user-supplied --tileset value is ignored - confirm this is intended.
    # ``height`` and ``width`` stay module-level names for any code that
    # reads them.
    if args.game == "Mario":
        args.num_tiles = common_settings.MARIO_TILE_COUNT
        height = common_settings.MARIO_HEIGHT
        width = common_settings.MARIO_WIDTH
        args.tile_size = common_settings.MARIO_TILE_PIXEL_DIM
        args.tileset = r'..\TheVGLC\Super Mario Bros\smb.json'
    elif args.game == "LR":
        args.num_tiles = common_settings.LR_TILE_COUNT
        height = common_settings.LR_HEIGHT
        width = common_settings.LR_WIDTH
        args.tile_size = common_settings.LR_TILE_PIXEL_DIM
        args.tileset = r'..\TheVGLC\Lode Runner\Loderunner.json'
    else:
        raise ValueError(f"Unknown game: {args.game}")

    ig = InteractiveLevelGeneration(args)
    ig.start()
|