| from typing import List |
|
|
| from dataclasses import dataclass |
| from .fast_sd import fast_diffusion_pipeline |
|
|
| import torch |
| import pygame |
| import numpy as np |
| import time |
| from PIL import Image |
|
|
| from .game_objects import Point, TextPrompt |
| from .sampling import ( |
| DistanceSampling, |
| CircleSampling |
| ) |
|
|
| @dataclass |
| class GameConfig: |
| point_thickness : float = 10 |
| zoom_speed : float = 0.75 |
| move_speed : float = 0.75 |
| point_font_size : int = 25 |
|
|
| prompt_font_size : int = 30 |
|
|
| |
| width : int = 1920 |
| height : int = 1080 |
|
|
| |
| sample_width : int = 512 |
| sample_height : int = 512 |
|
|
| compile : bool = False |
| sampler : str = "distance" |
| seed : int = 0 |
| call_every : int = 90 |
|
|
| class LatentSpaceExplorer: |
| def __init__(self, config : GameConfig = GameConfig()): |
| self.config = config |
|
|
| self.pipe = fast_diffusion_pipeline(compile = self.config.compile) |
| self.points : List[Point] = [] |
| self.player_pos = None |
|
|
| self.dragging_point_idx = None |
| self.selected_point_idx = None |
|
|
| self.zoom_level = 300.0 |
| self.translation = np.array([-self.config.width/2, -self.config.height/2]) |
|
|
| self.point_kwargs = {} |
| if self.config.sampler == "distance": |
| self.sampler = DistanceSampling |
| elif self.config.sampler == "circle": |
| self.sampler = CircleSampling |
| self.point_kwargs['on_edge'] = True |
| else: |
| raise ValueError(f"Invalid sampler choice: {self.config.sampler}") |
|
|
| pygame.init() |
| self.screen = pygame.display.set_mode((self.config.width, self.config.height)) |
| self.clock = pygame.time.Clock() |
| self.ms_elapsed = 0 |
|
|
| |
| self.avg_latency = (0, 0) |
| |
| self.sample_image = None |
| self.sample_font = pygame.font.Font(None, self.config.point_font_size) |
| |
| |
| self.input_font = pygame.font.Font(None, self.config.prompt_font_size) |
| self.inputting_text = False |
| self.inputting_text_for = None |
| self.text_prompt : TextPrompt = None |
| |
| def tick(self): |
| self.clock.tick() |
| self.ms_elapsed += self.clock.get_time() |
|
|
| def update_latency(self, new_observation): |
| n = self.avg_latency[0] |
| old_avg = self.avg_latency[1] |
| self.avg_latency = (n + 1, (old_avg * n + new_observation) / (n + 1)) |
|
|
| def create_text_prompt(self, prompt_text): |
| self.text_prompt = TextPrompt(prompt_text, self.input_font, self.screen) |
|
|
| def switch_sampler(self): |
| if self.config.sampler == "distance": |
| self.config.sampler = "circle" |
| self.sampler = CircleSampling |
| self.point_kwargs = {'on_edge' : True} |
| elif self.config.sampler == "circle": |
| self.config.sampler = "distance" |
| self.sampler = DistanceSampling |
| self.point_kwargs = {} |
| self.set_prompts(self.prompts, reset = True) |
| |
| @property |
| def encodes(self): |
| """ |
| Get encodings directly from points as a tuple with batched encodings |
| """ |
| if not self.points: |
| return None |
| encode_list = [p.encoding for p in self.points] |
| n = len(encode_list[0]) |
| res = [] |
| for i in range(n): |
| res.append(torch.cat([e[i] for e in encode_list], dim = 0) if encode_list[0][i] is not None else None) |
|
|
| return tuple(res) |
| |
| @property |
| def prompts(self): |
| """ |
| Get a list of current prompts |
| """ |
| return [p.text for p in self.points] |
| |
| @property |
| def r2_points(self): |
| """ |
| Get all points in terms of R2 space |
| """ |
| points = [np.array(p.xy_pos) for p in self.points] |
| points = np.stack(points, axis = 0) |
| return points |
|
|
| @property |
| def screen_space_points(self): |
| """ |
| Get all points in terms of screen space |
| """ |
| screen_space = (self.r2_points * self.zoom_level) - self.translation[None,:] |
| return screen_space |
| |
| @property |
| def mouse_pos(self): |
| return np.array(pygame.mouse.get_pos()) |
|
|
| def invert_screen_space(self, point): |
| """ |
| taking position as [2,] np array in screen space, return R2 pos |
| """ |
| return (point + self.translation) / self.zoom_level |
|
|
| def screen_space(self, point): |
| """ |
| R2 -> screenspace as [2,] array |
| """ |
| return (point * self.zoom_level) - self.translation |
| |
| def fixed_seed(self): |
| """ |
| Controls random number generator for initial latent noise |
| """ |
| return torch.Generator('cuda').manual_seed(self.config.seed) |
| |
| def get_encodes(self, text): |
| """ |
| Get text encodings for some prompt then split them so we can associate points with thier encodings |
| """ |
| encodes = self.pipe.get_encodes(text, generator = self.fixed_seed()) |
| |
| if not isinstance(encodes, tuple) and not isinstance(encodes, list): |
| return encodes |
| |
| res_list = [] |
| for i in range(len(encodes[0])): |
| res_list_i = [encodes_j[i].unsqueeze(0) if encodes_j is not None else None for encodes_j in encodes] |
| res_list.append(tuple(res_list_i)) |
|
|
| return res_list |
| |
| def draw_sample(self): |
| """ |
| Draw sample with current points and player position |
| """ |
| if self.player_pos is not None and self.encodes is not None: |
| if self.ms_elapsed >= self.config.call_every: |
| time_start = time.time() |
| encoding = self.sampler(self.encodes)(self.player_pos, self.r2_points) |
| self.sample_image = self.pipe.generate_from_encodes(encoding, generator = self.fixed_seed()).images[0] |
| time_total = float(time.time() - time_start) * 1000 |
|
|
| self.update_latency(time_total) |
|
|
| self.ms_elapsed = 0 |
|
|
| def get_player_pos_r2(self): |
| """ |
| Get player position in R2 from the |
| """ |
| self.player_pos = self.invert_screen_space(self.mouse_pos) |
|
|
| def get_player_pos_screenspace(self): |
| """ |
| Get player pos in screen space |
| """ |
| if self.player_pos is not None: return self.screen_space(self.player_pos) |
| |
| def detect_mouse_on_point(self): |
| """ |
| Detect if mouse is currently in a point. If so, returns index of point, otherwise returns none. |
| """ |
| if not self.points: |
| return None |
| |
| mouse_pos = self.mouse_pos |
| points = self.screen_space_points |
|
|
| distances = np.linalg.norm(points - mouse_pos[None,:], axis = 1) |
| close_idx = np.argmin(distances) |
|
|
| if distances[close_idx] <= self.config.point_thickness: |
| return close_idx |
| return None |
|
|
| |
|
|
| def modify_node(self, new_prompt): |
| idx = self.selected_point_idx |
| new_prompts = self.prompts |
| new_prompts[idx] = new_prompt |
| self.set_prompts(new_prompts, reset = False) |
| |
| def add_node(self, new_prompt): |
| self.set_prompts(self.prompts + [new_prompt], reset = False) |
|
|
| def del_node(self): |
| idx = self.selected_point_idx |
| new_prompts = list(self.prompts) |
| del new_prompts[idx] |
| self.set_prompts(new_prompts, reset = False) |
| self.selected_point_idx = None |
|
|
| def prepare_to_prompt(self, mode): |
| """ |
| Get ready to show the textbox. Call when we want the text prompt to come |
| """ |
| self.inputting_text = True |
| self.inputting_text_for = mode |
|
|
| if mode == "modify": |
| self.create_text_prompt("Enter New Prompt To Replace Node:") |
| elif mode == "add": |
| self.create_text_prompt("Enter New Prompt To Create Node:") |
| |
| def handle_prompt(self): |
| """ |
| After enter pressed with textbox, this is called to go back to normal game |
| """ |
| done_prompting = self.text_prompt.update() |
|
|
| if done_prompting: |
| new_prompt = self.text_prompt.user_input.strip() |
| if self.inputting_text_for == "modify": |
| self.modify_node(new_prompt) |
| elif self.inputting_text_for == "add": |
| self.add_node(new_prompt) |
| self.text_prompt = None |
| self.inputting_text = False |
|
|
| def set_prompts(self, prompts : List[str], reset : bool = False): |
| """ |
| :param prompts: New prompts to update to |
| :param reset: Reset xy positions of points? |
| """ |
|
|
| if len(prompts) > 0: |
| encodes = self.get_encodes(prompts) |
|
|
| |
| if not self.points or reset: |
| self.points = [Point(prompt, encoding, xy_init_kwargs = self.point_kwargs) for (prompt, encoding) in zip(prompts, encodes)] |
| return |
| |
| |
| old_len = len(self.points) |
| new_len = len(prompts) |
| |
| pos = [tuple(pos_i) for pos_i in self.r2_points] |
|
|
| if old_len <= new_len: |
| pos += [None] * (new_len - old_len) |
| self.points = [Point(prompt, encoding, pos_i, xy_init_kwargs = self.point_kwargs) for (prompt, encoding, pos_i) in zip(prompts, encodes, pos)] |
| return |
| elif old_len > new_len: |
| idx_to_keep = [] |
| for idx, prompt in enumerate(self.prompts): |
| if prompt in prompts: |
| idx_to_keep.append(idx) |
| self.points = [self.points[idx] for idx in idx_to_keep] |
| return |
|
|
| |
|
|
| def handle_event_controls(self): |
| """ |
| Handles discrete (i.e. keydown, mousedown) controls through events |
| """ |
| for event in pygame.event.get(): |
| if event.type == pygame.QUIT: |
| pygame.quit() |
| quit() |
| elif event.type == pygame.MOUSEBUTTONDOWN: |
| |
| if event.button == 1: |
| self.selected_point_idx = self.detect_mouse_on_point() |
| if self.selected_point_idx is not None: self.dragging_point_idx = None |
| else: |
| self.get_player_pos_r2() |
| self.draw_sample() |
| elif event.button == 3: |
| self.dragging_point_idx = self.detect_mouse_on_point() |
| if self.dragging_point_idx is not None: self.selected_point_idx = None |
| elif event.type == pygame.MOUSEBUTTONUP: |
| if event.button == 3: |
| self.dragging_point_idx = None |
| elif event.type == pygame.MOUSEMOTION: |
| if self.dragging_point_idx is not None: |
| |
| self.points[self.dragging_point_idx].move(self.invert_screen_space(self.mouse_pos)) |
| elif pygame.mouse.get_pressed()[0]: |
| self.get_player_pos_r2() |
| self.draw_sample() |
| elif event.type == pygame.KEYDOWN: |
| keys = pygame.key.get_pressed() |
| if keys[pygame.K_r]: |
| self.set_prompts(self.prompts, reset = True) |
| elif keys[pygame.K_t] and self.selected_point_idx is not None: |
| self.prepare_to_prompt("modify") |
| return |
| elif keys[pygame.K_p]: |
| self.prepare_to_prompt("add") |
| return |
| elif keys[pygame.K_o] and self.selected_point_idx is not None: |
| |
| self.del_node() |
| elif keys[pygame.K_g]: |
| if self.sample_image is not None: |
| self.sample_image.save("sample.png") |
| elif keys[pygame.K_m]: |
| |
| self.switch_sampler() |
|
|
|
|
| def handle_continuous_controls(self): |
| """ |
| Continuous controls for movement (i.e. zoom, movement) |
| """ |
| keys = pygame.key.get_pressed() |
| if keys[pygame.K_q]: |
| self.zoom_level = max(0.01, self.zoom_level - self.config.zoom_speed) |
| elif keys[pygame.K_e]: |
| self.zoom_level = self.zoom_level + self.config.zoom_speed |
|
|
| idx, sign = None, None |
|
|
| |
| if keys[pygame.K_w]: |
| idx, sign = 1, -1 |
| elif keys[pygame.K_s]: |
| idx, sign = 1, 1 |
| elif keys[pygame.K_a]: |
| idx, sign = 0, -1 |
| elif keys[pygame.K_d]: |
| idx, sign = 0, 1 |
|
|
| if idx is not None and sign is not None: |
| self.translation[idx] += sign * self.config.move_speed |
|
|
| |
|
|
| def draw_main_screen(self): |
| """ |
| Draw main screen. Sample image, points, etc. |
| """ |
| def get_point_color(idx): |
| color = (255, 255, 255) |
| if idx == self.selected_point_idx: |
| color = (0, 127.5, 0) |
| if idx == self.dragging_point_idx: |
| color = (255, 0, 0) |
| return color |
| |
| if self.config.sampler == "circle": |
| |
| center = np.array([0,0]) |
| border = np.array([1,0]) |
|
|
| center = self.screen_space(center) |
| border = self.screen_space(border) |
| radius = abs(border[0] - center[0]) |
|
|
| pygame.draw.circle(self.screen, (255, 255, 255), center, int(radius), 1) |
| |
| if len(self.points) > 0: |
| for idx, point in enumerate(self.screen_space_points): |
| pygame.draw.circle(self.screen, get_point_color(idx), point, self.config.point_thickness) |
| text = self.sample_font.render(self.points[idx].text, True, get_point_color(idx)) |
| self.screen.blit(text, point) |
| |
| player_pos = self.get_player_pos_screenspace() |
| if player_pos is not None: |
| pygame.draw.circle(self.screen, (0, 255, 0), player_pos, self.config.point_thickness/2) |
| |
| if self.sample_image is not None: |
| pygame_image = pygame.image.fromstring(self.sample_image.tobytes(), self.sample_image.size, self.sample_image.mode) |
| pygame_image = pygame.transform.scale(pygame_image, (self.config.sample_width, self.config.sample_height)) |
| self.screen.blit(pygame_image, (0, 0)) |
|
|
| def update(self): |
| """ |
| Main pygame loop |
| """ |
|
|
| if not self.inputting_text: |
| self.handle_event_controls() |
| self.handle_continuous_controls() |
| self.tick() |
| |
| self.screen.fill((0,0,0)) |
| self.draw_main_screen() |
|
|
| |
| if self.inputting_text: |
| self.handle_prompt() |
| pygame.display.flip() |
|
|