File size: 8,588 Bytes
a09cfc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from interactive_generation import InteractiveGeneration
import torch
from level_dataset import visualize_samples, convert_to_level_format, positive_negative_caption_split
from caption_match import compare_captions, process_scene_segments
from create_ascii_captions import assign_caption
from util import extract_tileset
from sampler import scene_to_ascii
import argparse
import common_settings as common_settings
from sampler import SampleOutput
from pipeline_loader import get_pipeline


def parse_args():
    """Parse the command-line options for interactive level generation.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace: with attributes ``model_path`` (required),
        ``tileset``, ``describe_absence``, ``automatic_negative_captions``
        and ``game`` (one of ``"Mario"`` or ``"LR"``).
    """
    parser = argparse.ArgumentParser(description="Generate levels using a trained diffusion model")
    # Model and generation parameters
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
    # Raw string: the path contains backslashes (\T, \S, \s) that are not valid
    # escape sequences, so a plain literal triggers an invalid-escape warning.
    # The resulting value is unchanged.
    parser.add_argument("--tileset", default=r'..\TheVGLC\Super Mario Bros\smb.json', help="Descriptions of individual tile types")
    # NOTE: a --describe_locations flag is intentionally disabled here.
    #parser.add_argument("--describe_locations", action="store_true", default=False, help="Include location descriptions in the captions")
    parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
    parser.add_argument("--automatic_negative_captions", action="store_true", default=False, help="Automatically create negative captions for prompts so the user doesn't have to")

    parser.add_argument(
        "--game",
        type=str,
        default="Mario",
        choices=["Mario", "LR"],
        help="Which game to create a model for (affects sample style and tile count)"
    )

    return parser.parse_args()

class InteractiveLevelGeneration(InteractiveGeneration):
    """Interactive front end that generates game levels from text captions.

    Wraps a diffusion pipeline loaded from ``args.model_path`` and, after each
    generation, captions the produced scene and scores it against the prompt.
    Supports the games ``"Mario"`` and ``"LR"`` (selected by ``args.game``).
    """

    def __init__(self, args):
        """Load the pipeline and tileset and declare the interactive inputs.

        Args:
            args: Parsed options; this method reads ``game``, ``model_path``,
                ``automatic_negative_captions`` and ``tileset``.

        Raises:
            ValueError: If automatic negative captions are requested but the
                loaded pipeline does not support negative prompts.
        """
        # Derive the default width from the selected game instead of relying
        # on a module-level ``width`` that only exists when this file is run
        # as a script (importing the class would otherwise hit a NameError).
        if args.game == "LR":
            default_width = common_settings.LR_WIDTH
        else:
            default_width = common_settings.MARIO_WIDTH

        super().__init__(
            {
                "caption": str,
                "width": int,
                "negative_prompt": str,
                "start_seed": int,
                "end_seed": int,
                "num_inference_steps": int,
                "guidance_scale": float
            },
            default_parameters={
                "width": default_width,
                "start_seed": 1,
                "end_seed": 1,  # Will be set to start_seed if blank
                "num_inference_steps": common_settings.NUM_INFERENCE_STEPS,
                "guidance_scale": common_settings.GUIDANCE_SCALE,
                "caption": "",
                "negative_prompt": ""
            }
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = get_pipeline(args.model_path).to(self.device)
        self.pipe.print_unet_architecture()
        #self.pipe.save_unet_architecture_pdf(height, width)

        # Reject the impossible combination before touching the parameter
        # tables, so a failure leaves no half-applied state behind.
        if args.automatic_negative_captions and not self.pipe.supports_negative_prompt:
            raise ValueError("Automatic negative caption generation is not possible with a model that doesn't support it")

        # The user is not asked for a negative prompt when it is generated
        # automatically or when the model cannot use one.
        if args.automatic_negative_captions or not self.pipe.supports_negative_prompt:
            self.input_parameters.pop('negative_prompt', None)
            self.default_parameters.pop('negative_prompt', None)

        if args.tileset:
            _, self.id_to_char, self.char_to_id, self.tile_descriptors = extract_tileset(args.tileset)

        self.args = args

        # Lode Runner width is fixed (set in get_extra_params), so it is not
        # offered as a user input.
        if self.args.game == "LR":
            del self.input_parameters["width"]

        print(f"Tileset in use: {self.args.tileset}")

    def generate_image(self, param_values, generator, **extra_params):
        """Generate one level, report how well it matches the prompt, and render it.

        Args:
            param_values: Keyword arguments forwarded to the pipeline. May be
                mutated: ``negative_prompt`` is added when automatic negative
                captions are enabled and a caption is present.
            generator: Random generator forwarded to the pipeline.
            **extra_params: Accepted for interface compatibility; unused here.

        Returns:
            The result of ``visualize_samples`` applied to the pipeline images.

        Raises:
            ValueError: If ``self.args.game`` is not a known game, or if the
                user answers the play prompt with something other than y/n.
        """
        game = self.args.game
        prompt = param_values.get("caption", "")

        # Skip the split when there is no caption at all (e.g. it was removed
        # because it was blank); indexing it would raise KeyError.
        if self.args.automatic_negative_captions and "caption" in param_values:
            _, neg = positive_negative_caption_split(param_values["caption"], True)
            param_values["negative_prompt"] = neg

        images = self.pipe(
            generator=generator,
            **param_values
        ).images

        # Convert to indices
        sample_tensor = images[0].unsqueeze(0)
        sample_indices = convert_to_level_format(sample_tensor)

        # Add level data to the list
        scene = sample_indices[0].tolist()
        if game == "LR":
            number_of_tiles = common_settings.LR_TILE_COUNT
            scene = [[x % number_of_tiles for x in row] for row in scene]

        # Arguments shared by both games' segment-scoring functions.
        segment_kwargs = dict(
            scene=scene,
            prompt=prompt,
            id_to_char=self.id_to_char,
            char_to_id=self.char_to_id,
            tile_descriptors=self.tile_descriptors,
            describe_locations=False, #self.args.describe_locations,
            describe_absence=self.args.describe_absence,
            verbose=True
        )

        # Assign a caption to the scene of whichever game is being played,
        # then score it. ``self.args`` is used throughout (not a module-level
        # ``args``) so the class also works when imported.
        if game == "Mario":
            actual_caption = assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
            level_width = common_settings.MARIO_WIDTH

            compare_score = compare_captions(prompt, actual_caption)
            print(f"Comparison score: {compare_score}")

            # Return value is unused; the call is kept for its verbose output.
            process_scene_segments(segment_width=level_width, **segment_kwargs)

            self._offer_to_play(scene)
        elif game == "LR":
            # NOTE(review): lr_assign_caption, lr_compare_captions and
            # lr_process_scene_segments are not imported in the visible part
            # of this file — confirm where they come from and import them,
            # otherwise this branch raises NameError.
            actual_caption = lr_assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
            level_width = common_settings.LR_WIDTH

            print(f"Describe resulting image: {actual_caption}")
            lr_compare_score = lr_compare_captions(prompt, actual_caption)
            print(f"Comparison score: {lr_compare_score}")

            # Return value is unused; the call is kept for its verbose output.
            lr_process_scene_segments(segment_width=level_width, **segment_kwargs)
        else:
            raise ValueError(f"Unknown game: {game}")

        return visualize_samples(images)

    def _offer_to_play(self, scene):
        """Ask whether to run the A* agent on ``scene`` and do so on 'y'.

        Raises:
            ValueError: If the answer is neither 'y' nor 'n'.
        """
        play_level = input("Do you want to play this level? (y/n): ").strip().lower()
        if play_level == 'y':
            print("Playing level...")
            char_grid = scene_to_ascii(scene, self.id_to_char, False)
            level = SampleOutput(level=char_grid, use_snes_graphics=False)
            console_output = level.run_astar()
            print(console_output)
        elif play_level == 'n':
            print("Level not played.")
        else:
            raise ValueError(f"Unknown input: {play_level}")

    def get_extra_params(self, param_values):
        """Normalize ``param_values`` in place before generation.

        Removes blank ``caption`` / ``negative_prompt`` entries, forces
        ``output_type`` to ``"tensor"``, and pins height/width for Lode Runner.

        Returns:
            dict: Always empty; all adjustments are made on ``param_values``.
        """
        # Blank prompts are dropped rather than passed as empty strings.
        if param_values.get("negative_prompt") == "":
            del param_values["negative_prompt"]

        if param_values.get("caption") == "":
            del param_values["caption"]

        param_values["output_type"] = "tensor"

        # Lode Runner
        if self.args.game == "LR":
            param_values["height"] = common_settings.LR_HEIGHT
            param_values["width"] = common_settings.LR_WIDTH

        return {}

if __name__ == "__main__":
    # Script entry point: parse options, fill in per-game settings on ``args``,
    # then start the interactive generation loop.
    args = parse_args()

    # Raw strings: these Windows-style paths contain backslashes (\T, \S, \s,
    # \L) that are not valid escape sequences in a normal string literal.
    mario_tileset = r'..\TheVGLC\Super Mario Bros\smb.json'
    lr_tileset = r'..\TheVGLC\Lode Runner\Loderunner.json'

    # ``height`` and ``width`` are kept as module-level names because other
    # code in this file may read them.
    if args.game == "Mario":
        args.num_tiles = common_settings.MARIO_TILE_COUNT
        height = common_settings.MARIO_HEIGHT
        width = common_settings.MARIO_WIDTH
        args.tile_size = common_settings.MARIO_TILE_PIXEL_DIM
        game_tileset = mario_tileset
    elif args.game == "LR":
        args.num_tiles = common_settings.LR_TILE_COUNT
        height = common_settings.LR_HEIGHT
        width = common_settings.LR_WIDTH
        args.tile_size = common_settings.LR_TILE_PIXEL_DIM
        game_tileset = lr_tileset
    else:
        raise ValueError(f"Unknown game: {args.game}")

    # Only substitute the per-game tileset when --tileset was left at the
    # parser default (the Mario path); an explicit --tileset is honored
    # instead of being silently overwritten.
    if args.tileset == mario_tileset:
        args.tileset = game_tileset

    ig = InteractiveLevelGeneration(args)
    ig.start()