| import os |
| import torch |
| import numpy as np |
| import imgui |
| import dnnlib |
| from gui_utils import imgui_utils |
|
|
| |
|
|
class DragWidget:
    """ImGui panel for placing drag handle/target points and painting a mask.

    Clicks in ``'point'`` mode alternate between appending to ``points``
    (handle points) and ``targets``; coordinates are stored as ``[y, x]``.
    The widget also owns a ``(height, width)`` float mask initialised to ones;
    the brush writes 0 into it in ``'flexible'`` mode and 1 in ``'fixed'``
    mode. Every frame ``__call__`` copies this state into ``viz.args``.
    """

    def __init__(self, viz):
        # viz: the owning visualizer; this widget reads viz.result, viz.label_w,
        # viz.button_w, viz.font_size, viz.frame_delta and writes viz.args.*.
        self.viz = viz
        self.point = [-1, -1]     # most recent click position, [y, x]
        self.points = []          # handle points, each [y, x]
        self.targets = []         # target points, each [y, x]
        self.is_point = True      # True: next click is a handle point; False: a target
        self.last_click = False   # mouse-button state seen on the previous add_point()
        self.is_drag = False
        self.iteration = 0
        self.mode = 'point'       # 'point', 'flexible' or 'fixed'
        self.r_mask = 50          # brush radius used by draw_mask()
        self.show_mask = False
        # Keep width/height consistent with the default mask so that
        # 'Reset mask' and draw_mask() work even before init_mask() is called.
        self.width, self.height = 256, 256
        self.mask = torch.ones(self.height, self.width)
        self.lambda_mask = 20
        self.feature_idx = 5
        self.r1 = 3
        self.r2 = 12
        self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames = 0
        self.disabled_time = 0

    def action(self, click, down, x, y):
        """Route a mouse event at (x, y) according to the current mode.

        In 'point' mode the event goes to add_point(); in the mask modes the
        brush is applied only while the button is held (``down``).
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Remember the click position and commit it on button release.

        While ``click`` is true the latest position is stored. On the first
        call after release, any running drag is stopped and the stored
        position is appended to ``points`` or ``targets`` (alternating).
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
            else:
                self.targets.append(self.point)
            self.is_point = not self.is_point
        self.last_click = click

    def init_mask(self, w, h):
        """Set the image size and reset the mask to an all-ones (h, w) tensor."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a disc of radius ``r_mask`` centred at pixel (x, y) into the mask.

        'flexible' mode writes 0 inside the disc, 'fixed' mode writes 1; any
        other mode leaves the mask untouched.
        """
        if self.mode == 'flexible':
            value = 0
        elif self.mode == 'fixed':
            value = 1
        else:
            return
        # Exact integer pixel coordinates 0..H-1 / 0..W-1. A column of row
        # indices broadcast against a row of column indices gives an (H, W)
        # grid without meshgrid (whose default indexing warns on newer torch).
        rows = torch.arange(self.height, dtype=torch.float32).unsqueeze(1)
        cols = torch.arange(self.width, dtype=torch.float32).unsqueeze(0)
        circle = (cols - x) ** 2 + (rows - y) ** 2 < self.r_mask ** 2
        self.mask[circle] = value

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the handle-point list (targets are left unchanged)."""
        self.points = points

    def reset_point(self):
        """Clear all handle and target points; the next click is a handle point."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read integer ``[y, x]`` points from ``self.path + '_<suffix>.txt'``.

        Each non-blank line must hold two whitespace-separated integers (y
        then x). If the file cannot be opened, or a line is malformed, a
        message is printed and the points parsed so far are returned.
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except OSError:
            print(f'Wrong point file path: {point_path}')
        except ValueError:
            print(f'Wrong point file format: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (when ``show``) and publish widget state to ``viz.args``.

        State is copied into ``viz.args`` every frame, whether or not the
        panel is drawn. ``viz.args.reset`` is True only on the frame where
        'Reset point' was pressed.
        """
        viz = self.viz
        reset = False
        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Drag')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Add point', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled='image' in viz.result):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled='image' in viz.result):
                    self.is_drag = True
                    if len(self.points) > len(self.targets):
                        # Drop the trailing handle point that has no target yet,
                        # and expect a handle point on the next click; otherwise
                        # that click would become an unpaired target.
                        self.points = self.points[:len(self.targets)]
                        self.is_point = True

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled='image' in viz.result):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled='image' in viz.result):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled='image' in viz.result):
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Shallow copies so later edits to the widget's lists don't alias args.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset
|
|
|
|
| |
|
|