schrum2 commited on
Commit
1df3d59
·
verified ·
1 Parent(s): b5baa0f

Don't think I need this

Browse files
Files changed (1) hide show
  1. text_to_level_diffusion.py +0 -194
text_to_level_diffusion.py DELETED
@@ -1,194 +0,0 @@
1
- from interactive_generation import InteractiveGeneration
2
- import torch
3
- from level_dataset import visualize_samples, convert_to_level_format, positive_negative_caption_split
4
- from caption_match import compare_captions, process_scene_segments
5
- from create_ascii_captions import assign_caption
6
- from util import extract_tileset
7
- from sampler import scene_to_ascii
8
- import argparse
9
- import common_settings as common_settings
10
- from sampler import SampleOutput
11
- from pipeline_loader import get_pipeline
12
-
13
-
14
- def parse_args():
15
- parser = argparse.ArgumentParser(description="Generate levels using a trained diffusion model")
16
- # Model and generation parameters
17
- parser.add_argument("--model_path", type=str, required=True, help="Path to the trained diffusion model")
18
- parser.add_argument("--tileset", default='..\TheVGLC\Super Mario Bros\smb.json', help="Descriptions of individual tile types")
19
- #parser.add_argument("--describe_locations", action="store_true", default=False, help="Include location descriptions in the captions")
20
- parser.add_argument("--describe_absence", action="store_true", default=False, help="Indicate when there are no occurrences of an item or structure")
21
- parser.add_argument("--automatic_negative_captions", action="store_true", default=False, help="Automatically create negative captions for prompts so the user doesn't have to")
22
-
23
-
24
- parser.add_argument(
25
- "--game",
26
- type=str,
27
- default="Mario",
28
- choices=["Mario", "LR"],
29
- help="Which game to create a model for (affects sample style and tile count)"
30
- )
31
-
32
- return parser.parse_args()
33
-
34
- class InteractiveLevelGeneration(InteractiveGeneration):
35
- def __init__(self, args):
36
- super().__init__(
37
- {
38
- "caption": str,
39
- "width": int,
40
- "negative_prompt": str,
41
- "start_seed": int,
42
- "end_seed": int,
43
- "num_inference_steps": int,
44
- "guidance_scale": float
45
- },
46
- default_parameters={
47
- "width": width, #common_settings.MARIO_WIDTH,
48
- "start_seed": 1,
49
- "end_seed": 1, # Will be set to start_seed if blank
50
- "num_inference_steps": common_settings.NUM_INFERENCE_STEPS,
51
- "guidance_scale": common_settings.GUIDANCE_SCALE,
52
- "caption": "",
53
- "negative_prompt": ""
54
- }
55
- )
56
-
57
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
- self.pipe = get_pipeline(args.model_path).to(self.device)
59
- self.pipe.print_unet_architecture()
60
- #self.pipe.save_unet_architecture_pdf(height, width)
61
-
62
- if args.automatic_negative_captions or not self.pipe.supports_negative_prompt:
63
- self.input_parameters.pop('negative_prompt', None)
64
- self.default_parameters.pop('negative_prompt', None)
65
-
66
- if args.automatic_negative_captions and not self.pipe.supports_negative_prompt:
67
- raise ValueError("Automatic negative caption generation is not possible with a model that doesn't support it")
68
-
69
- if args.tileset:
70
- _, self.id_to_char, self.char_to_id, self.tile_descriptors = extract_tileset(args.tileset)
71
-
72
- self.args = args
73
-
74
- if self.args.game == "LR":
75
- del self.input_parameters["width"]
76
-
77
- print(f"Tileset in use: {self.args.tileset}")
78
-
79
- def generate_image(self, param_values, generator, **extra_params):
80
- if self.args.automatic_negative_captions:
81
- pos, neg = positive_negative_caption_split(param_values["caption"], True)
82
- param_values["negative_prompt"] = neg
83
- images = self.pipe(
84
- generator=generator,
85
- **param_values
86
- ).images
87
-
88
- # Convert to indices
89
- sample_tensor = images[0].unsqueeze(0)
90
- sample_indices = convert_to_level_format(sample_tensor)
91
-
92
- # Add level data to the list
93
- scene = sample_indices[0].tolist()
94
- if self.args.game == "LR":
95
- number_of_tiles = common_settings.LR_TILE_COUNT
96
- scene = [[x % number_of_tiles for x in row] for row in scene]
97
-
98
- # Assign a caption to the sceneof whichever game is being played
99
- if self.args.game == "Mario":
100
- actual_caption = assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
101
- level_width = common_settings.MARIO_WIDTH
102
- elif self.args.game == "LR":
103
- actual_caption = lr_assign_caption(scene, self.id_to_char, self.char_to_id, self.tile_descriptors, False, self.args.describe_absence)
104
- level_width = common_settings.LR_WIDTH
105
- else:
106
- raise ValueError(f"Unknown game: {self.args.game}")
107
-
108
- if args.game == "LR":
109
- print(f"Describe resulting image: {actual_caption}")
110
- lr_compare_score = lr_compare_captions(param_values.get("caption", ""), actual_caption)
111
- print(f"Comparison score: {lr_compare_score}")
112
-
113
- # Use the new function to process scene segments
114
- average_score, segment_captions, segment_scores = lr_process_scene_segments(
115
- scene=scene,
116
- segment_width=level_width,
117
- prompt=param_values.get("caption", ""),
118
- id_to_char=self.id_to_char,
119
- char_to_id=self.char_to_id,
120
- tile_descriptors=self.tile_descriptors,
121
- describe_locations=False, #self.args.describe_locations,
122
- describe_absence=self.args.describe_absence,
123
- verbose=True
124
- )
125
-
126
- elif args.game == "Mario":
127
- compare_score = compare_captions(param_values.get("caption", ""), actual_caption)
128
- print(f"Comparison score: {compare_score}")
129
-
130
- # Use the new function to process scene segments
131
- average_score, segment_captions, segment_scores = process_scene_segments(
132
- scene=scene,
133
- segment_width=level_width,
134
- prompt=param_values.get("caption", ""),
135
- id_to_char=self.id_to_char,
136
- char_to_id=self.char_to_id,
137
- tile_descriptors=self.tile_descriptors,
138
- describe_locations=False, #self.args.describe_locations,
139
- describe_absence=self.args.describe_absence,
140
- verbose=True
141
- )
142
-
143
- # Ask if user wants to play level
144
- play_level = input("Do you want to play this level? (y/n): ").strip().lower()
145
- if play_level == 'y':
146
- print("Playing level...")
147
- char_grid = scene_to_ascii(scene, self.id_to_char, False)
148
- level = SampleOutput(level=char_grid, use_snes_graphics=False)
149
- console_output = level.run_astar()
150
- print(console_output)
151
- elif play_level == 'n':
152
- print("Level not played.")
153
- else:
154
- raise ValueError(f"Unknown input: {play_level}")
155
-
156
- return visualize_samples(images)
157
-
158
- def get_extra_params(self, param_values):
159
- if "negative_prompt" in param_values and param_values["negative_prompt"] == "":
160
- del param_values["negative_prompt"]
161
-
162
- if param_values["caption"] == "":
163
- del param_values["caption"]
164
-
165
- param_values["output_type"] = "tensor"
166
-
167
- # Lode Runner
168
- if self.args.game == "LR":
169
- param_values["height"] = common_settings.LR_HEIGHT
170
- param_values["width"] = common_settings.LR_WIDTH
171
-
172
- return dict()
173
-
174
- if __name__ == "__main__":
175
- args = parse_args()
176
-
177
- if args.game == "Mario":
178
- args.num_tiles = common_settings.MARIO_TILE_COUNT
179
- height = common_settings.MARIO_HEIGHT
180
- width = common_settings.MARIO_WIDTH
181
- args.tile_size = common_settings.MARIO_TILE_PIXEL_DIM
182
- args.tileset = '..\TheVGLC\Super Mario Bros\smb.json'
183
- elif args.game == "LR":
184
- args.num_tiles = common_settings.LR_TILE_COUNT
185
- height = common_settings.LR_HEIGHT
186
- width = common_settings.LR_WIDTH
187
- args.tile_size = common_settings.LR_TILE_PIXEL_DIM
188
- args.tileset = '..\TheVGLC\Lode Runner\Loderunner.json'
189
- else:
190
- raise ValueError(f"Unknown game: {args.game}")
191
-
192
- ig = InteractiveLevelGeneration(args)
193
- ig.start()
194
-