Spaces:
Paused
Paused
| import logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| import os | |
| # os.environ['CUDA_VISIBLE_DEVICES'] = '1' | |
| import cv2 | |
| import imageio | |
| import time | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import plotly.express as px | |
| import torch | |
| import dash | |
| from dash import Dash, Input, Output, dcc, html, State | |
| from dash.exceptions import PreventUpdate | |
| from .self_prompting import grounding_dino_prompt | |
def mark_image(_img, points, radius=10, color=(255, 0, 0)):
    """Return a copy of ``_img`` with a filled square drawn around each point.

    Args:
        _img: image array indexed as ``img[row, col, channel]``. Assumes an
            (H, W, C) layout with C == len(color) so the marker color
            broadcasts — TODO confirm for non-RGB inputs.
        points: non-empty sequence of ``(x, y)`` pixel coordinates, where
            ``x`` is the column and ``y`` is the row.
        radius: half-width of each square marker in pixels (default 10, the
            previously hard-coded value).
        color: marker color, one value per channel (default red).

    Returns:
        The marked copy; ``_img`` itself is not modified.

    Raises:
        ValueError: if ``points`` is empty.
    """
    if len(points) == 0:
        raise ValueError("mark_image requires at least one point")
    img = _img.copy()
    mark_color = np.asarray(color).reshape(1, 1, -1)
    for point in points:
        x, y = int(point[0]), int(point[1])
        # Clamp the start indices at 0: a negative slice start would wrap
        # around to the far edge, producing an empty slice and silently
        # dropping the marker for clicks near the top/left border.
        row_start = max(y - radius, 0)
        col_start = max(x - radius, 0)
        img[row_start:y + radius + 1, col_start:x + radius + 1] = mark_color
    return img
def draw_figure(fig, title, animation_frame=None, frame_duration=33):
    """Render an image (or a stack of frames) as a titled Plotly figure.

    Args:
        fig: pixel data passed to ``plotly.express.imshow`` (despite the
            name, this is an image array, not a ``Figure``).
        title: figure title text.
        animation_frame: forwarded to ``px.imshow``; when not ``None`` the
            given axis is played back as an animation.
        frame_duration: per-frame delay in milliseconds applied to the first
            animation control button (default 33, i.e. ~30 frames/s).

    Returns:
        A Plotly figure with the legend and axis tick labels hidden.
    """
    figure = px.imshow(fig, animation_frame=animation_frame)
    # Only patch the playback speed when plotly actually created controls.
    if animation_frame is not None and figure.layout.updatemenus:
        figure.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = frame_duration
    figure.update_layout(title_text=title, showlegend=False)
    figure.update_xaxes(showticklabels=False)
    figure.update_yaxes(showticklabels=False)
    return figure
class Sam3dGUI:
    """Dash front-end for prompting SAM on one view and driving SA3D training.

    The page lets the user prompt SAM on the current image (clicked points,
    or a text query turned into a box by ``grounding_dino_prompt``), pick one
    of the three returned masks, run training through ``Seg3d.train_step``,
    and finally show the animated results returned by ``Seg3d.render_test``.

    Args:
        Seg3d: training backend. This class uses its ``init_model()``,
            ``predictor``, ``train_step(idx, sam_mask=...)``, ``save_ckpt()``
            and ``render_test()`` members.
        debug: forwarded to the Dash server's ``debug`` flag.
    """

    def __init__(self, Seg3d, debug=False):
        # Mutable state shared between the Dash callbacks.
        ctx = {
            'num_clicks': 0,         # number of point prompts collected
            'click': [],             # collected (x, y) click coordinates
            'cur_img': None,         # image that SAM is prompted on
            'btn_clear': 0,          # last seen n_clicks of "Clear Points"
            'btn_text': 0,           # last seen n_clicks of "Generate"
            'prompt_type': 'point',  # 'point' or 'text'
            'show_rgb': False        # True when a new training preview is ready
        }
        self.ctx = ctx
        self.Seg3d = Seg3d
        self.debug = debug
        self.train_idx = 0

    def run(self):
        """Initialise the backend model, then build and serve the Dash app."""
        init_rgb = self.Seg3d.init_model()
        self.ctx['cur_img'] = init_rgb
        self.run_app(sam_pred=self.Seg3d.predictor, ctx=self.ctx, init_rgb=init_rgb)

    def run_app(self, sam_pred, ctx, init_rgb):
        '''
        run dash app

        Args:
            sam_pred: SAM predictor whose image has already been set; uses
                ``predict``, ``predict_torch`` and ``transform``.
            ctx: shared state dict (the callers in this file pass
                ``self.ctx``).
            init_rgb: image shown before any prompt is given.
        '''
        def query(points=None, text=None):
            """Run SAM on ``ctx['cur_img']`` and build the four preview figures.

            Exactly one of ``points`` (an (N, 2) array of (x, y) clicks) or
            ``text`` must be given.

            Returns:
                ``(masks, fig0, fig1, fig2, fig3)``: the three candidate masks,
                the prompted image and one overlay figure per mask. Returns
                ``None`` when the text prompt yields no box.
            """
            with torch.no_grad():
                if text is None:
                    input_point = points
                    # Label 1 for every click (SAM's foreground label).
                    input_label = np.ones(len(input_point))
                    masks, scores, logits = sam_pred.predict(
                        point_coords=input_point,
                        point_labels=input_label,
                        multimask_output=True,
                    )
                elif points is None:
                    input_boxes = grounding_dino_prompt(ctx['cur_img'], text)
                    # NOTE(review): assumes an empty result means "no box
                    # found"; without this guard masks[0] below raises.
                    if len(input_boxes) == 0:
                        return None
                    # Follow the predictor's device instead of assuming CUDA.
                    device = getattr(sam_pred, 'device', 'cuda')
                    boxes = torch.tensor(input_boxes)[0:1].to(device)
                    transformed_boxes = sam_pred.transform.apply_boxes_torch(boxes, ctx['cur_img'].shape[:2])
                    masks, scores, logits = sam_pred.predict_torch(
                        point_coords=None,
                        point_labels=None,
                        boxes=transformed_boxes,
                        multimask_output=True,
                    )
                    # Batched output: keep the masks of the single box used.
                    masks = masks[0].cpu().numpy()
                else:
                    raise NotImplementedError
            cur_img = ctx['cur_img']
            # Blend each mask (60%) over the image (40%) for display.
            mask_figs = [
                draw_figure((255 * masks[i, :, :, None] * 0.6 + cur_img * 0.4).astype(np.uint8), f'mask{i}')
                for i in range(3)
            ]
            fig0 = mark_image(cur_img, points) if text is None else cur_img
            fig0 = draw_figure(fig0, 'original_image')
            return (masks, fig0, *mask_figs)

        # Initial (empty) figures; fig0..fig3 are also restored by "Clear Points".
        blank = np.zeros_like(init_rgb)
        self.ctx['fig0'] = draw_figure(init_rgb, 'original_image')
        for key, title in (('fig1', 'mask0'),
                           ('fig2', 'mask1'),
                           ('fig3', 'mask2'),
                           ('fig_seg_rgb', 'Masked image in Training'),
                           ('fig_sam_mask', 'SAM Mask with Prompts in Training'),
                           ('fig_masked_rgb', 'Masked RGB'),
                           ('fig_seged_rgb', 'Seged RGB')):
            self.ctx[key] = draw_figure(blank, title)

        def _graph_cell(graph_id, figure):
            """Wrap a dcc.Graph in the 40%-wide inline-block cell used by every panel."""
            return html.Div(children=[
                dcc.Graph(id=graph_id, figure=figure)
            ], style={'display': 'inline-block', 'width': '40%'})

        app = dash.Dash(
            __name__, meta_tags=[{"name": "viewport", "content": "width=device-width"}]
        )
        app.layout = html.Div(
            style={"height": "100%"},
            children=[
                # Section 1: SAM prompting on the current view.
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SAM Init'])]),
                            html.Br(),
                            html.H5('Prompt Type:'),
                            html.Div([
                                dcc.Dropdown(
                                    id='prompt_type',
                                    options=[{'label': 'Points', 'value': 'point'},
                                             {'label': 'Text', 'value': 'text'}],
                                    value='point'),
                                html.Div(id='output-prompt_type')
                            ]),
                            html.Br(),
                            html.H5('Point Prompts:'),
                            html.Button('Clear Points', id='btn-nclicks-clear', n_clicks=0),
                            html.Br(),
                            html.H5('Text Prompt:'),
                            html.Div([
                                dcc.Input(id='input-text-state', type='text', value='none'),
                                html.Button(id='submit-button-state', n_clicks=0, children='Generate'),
                                html.Div(id='output-state-text')
                            ]),
                            html.Br(),
                            html.H5('Please select the mask:'),
                            html.Div([
                                dcc.RadioItems(['mask0', 'mask1', 'mask2'], id='sel_mask_id', value=None)
                            ], style={'display': 'flex'}),
                            html.Br(),
                            html.H5(id='container-sel-mask'),
                        ]),
                        html.Div(className="ten columns", children=[
                            _graph_cell('main_image', self.ctx['fig0']),
                            _graph_cell('mask0', self.ctx['fig1']),
                            _graph_cell('mask1', self.ctx['fig2']),
                            _graph_cell('mask2', self.ctx['fig3']),
                        ])
                    ])
                ]),
                # Section 2: training controls and live previews.
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SA3D Training'])]),
                            html.Br(),
                            html.Button('Start Training', id='btn-nclicks-training', n_clicks=0),
                            html.Div(id='container-button-training', style={'display': 'inline-block'}),
                        ]),
                        html.Div(className="ten columns", children=[
                            _graph_cell('seg_rgb', self.ctx['fig_seg_rgb']),
                            _graph_cell('sam_mask', self.ctx['fig_sam_mask']),
                        ]),
                        # Polls for new training previews (see displaySeg).
                        dcc.Interval(
                            id='interval-component',
                            interval=1*1000,  # in milliseconds
                            n_intervals=0),
                    ])
                ]),
                # Section 3: final rendering results.
                html.Div(className="container", children=[
                    html.Div(className="row", children=[
                        html.Div(className="two columns", style={"padding-bottom": "5%"}, children=[
                            html.Div([html.H3(['SA3D Rendering Results'])]),
                            html.Br(),
                        ]),
                        html.Div(className="ten columns", children=[
                            _graph_cell('masked_rgb', self.ctx['fig_masked_rgb']),
                            _graph_cell('seged_rgb', self.ctx['fig_seged_rgb']),
                        ]),
                    ])
                ])
            ])

        @app.callback(Output('output-prompt_type', 'children'),
                      [Input('prompt_type', 'value')])
        def update_prompt_type(value):
            """Record the chosen prompt type; leaving point mode drops the clicks."""
            self.ctx['prompt_type'] = value
            if value != 'point':
                ctx['click'] = []
                ctx['num_clicks'] = 0
            return f"Type {value} is chosen"

        @app.callback(
            [Output('main_image', 'figure'),
             Output('mask0', 'figure'),
             Output('mask1', 'figure'),
             Output('mask2', 'figure'),
             Output('output-state-text', 'children')],
            [Input('main_image', 'clickData'),
             Input('btn-nclicks-clear', 'n_clicks'),
             Input('submit-button-state', 'n_clicks')],
            [State('input-text-state', 'value')])
        def update_prompt(clickData, btn_point, btn_text, text):
            '''
            update mask

            Fired by an image click, "Clear Points" or "Generate". Dash passes
            the current value of every input whichever one fired, so a button
            press is detected by its n_clicks exceeding the last value seen.
            Both counters are re-synced by assignment on every call (not
            ``+= 1``), so presses made in the other prompt mode, or a page
            reload that resets them to 0, cannot trigger later spurious
            clears or generations.
            '''
            btn_point = btn_point or 0
            btn_text = btn_text or 0
            clear_pressed = btn_point > self.ctx['btn_clear']
            generate_pressed = btn_text > self.ctx['btn_text']
            self.ctx['btn_clear'] = btn_point
            self.ctx['btn_text'] = btn_text

            if self.ctx['prompt_type'] == 'point':
                if clear_pressed:
                    ctx['click'] = []
                    ctx['num_clicks'] = 0
                    return self.ctx['fig0'], self.ctx['fig1'], self.ctx['fig2'], self.ctx['fig3'], 'none'
                if generate_pressed or clickData is None:
                    # "Generate" is a text-mode control; ignoring it here avoids
                    # re-appending the stale last click still held in clickData.
                    raise PreventUpdate
                ctx['num_clicks'] += 1
                clicked = clickData['points'][0]
                ctx['click'].append(np.array([clicked['x'], clicked['y']]))
                ctx['saved_click'] = np.stack(ctx['click'])
                masks, fig0, fig1, fig2, fig3 = query(ctx['saved_click'])
                ctx['masks'] = masks
                return fig0, fig1, fig2, fig3, 'none'
            elif self.ctx['prompt_type'] == 'text':
                if not generate_pressed:
                    raise PreventUpdate
                self.ctx['text'] = text
                result = query(points=None, text=text)
                if result is None:
                    return (dash.no_update, dash.no_update, dash.no_update,
                            dash.no_update, f'No box found for "{text}"')
                masks, fig0, fig1, fig2, fig3 = result
                ctx['masks'] = masks
                return fig0, fig1, fig2, fig3, u'''
                Input text is "{}"
                '''.format(text)
            else:
                raise NotImplementedError

        @app.callback(Output('container-sel-mask', 'children'),
                      [Input('sel_mask_id', 'value')])
        def update_graph(radio_items):
            """Store the index of the mask chosen in the radio buttons."""
            mask_names = ('mask0', 'mask1', 'mask2')
            if radio_items not in mask_names:
                raise PreventUpdate
            ctx['select_mask_id'] = mask_names.index(radio_items)
            return html.Div(f"you select {radio_items}")

        @app.callback([Output('seg_rgb', 'figure'),
                       Output('sam_mask', 'figure')],
                      [Input('interval-component', 'n_intervals')])
        def displaySeg(n):
            """On each interval tick, show the latest training preview if a new one arrived."""
            if not self.ctx['show_rgb']:
                raise PreventUpdate
            self.ctx['show_rgb'] = False
            fig_seg_rgb = draw_figure(self.ctx['fig_seg_rgb'], 'Masked image in Training')
            fig_sam_mask = draw_figure(self.ctx['fig_sam_mask'], 'SAM Mask with Prompts in Training')
            return fig_seg_rgb, fig_sam_mask

        @app.callback(
            [Output('container-button-training', 'children'),
             Output('masked_rgb', 'figure'),
             Output('seged_rgb', 'figure')],
            [Input('btn-nclicks-training', 'n_clicks')])
        def start_training(btn):
            """Train with the selected SAM mask, then render and show the results."""
            if (btn or 0) < 1:
                return html.Div("Press to start training"), self.ctx['fig_masked_rgb'], self.ctx['fig_seged_rgb']
            if ctx.get('masks') is None or ctx.get('select_mask_id') is None:
                # The first training step needs a SAM mask; without this check
                # the lookups below raise KeyError.
                return (html.Div("Please generate and select a mask before training"),
                        self.ctx['fig_masked_rgb'], self.ctx['fig_seged_rgb'])
            # optim in the first view, with the user-selected SAM mask
            self.Seg3d.train_step(self.train_idx, sam_mask=ctx['masks'][ctx['select_mask_id']])
            self.train_idx += 1
            # cross-view training: step until the backend reports completion,
            # publishing each preview for displaySeg to pick up.
            while True:
                rgb, sam_prompt, is_finished = self.Seg3d.train_step(self.train_idx)
                self.train_idx += 1
                self.ctx['fig_seg_rgb'] = rgb
                self.ctx['fig_sam_mask'] = sam_prompt
                self.ctx['show_rgb'] = True
                if is_finished:
                    break
            self.Seg3d.save_ckpt()
            masked_rgb, seged_rgb = self.Seg3d.render_test()
            fig_masked_rgb = draw_figure(masked_rgb, 'Masked RGB', animation_frame=0)
            fig_seged_rgb = draw_figure(seged_rgb, 'Seged RGB', animation_frame=0)
            return html.Div("Train Stage Finished! Press Ctrl+C to Exit!"), fig_masked_rgb, fig_seged_rgb

        # ``app.run`` is the current spelling; fall back to the deprecated
        # ``run_server`` on older Dash releases that lack it.
        serve = getattr(app, 'run', None) or app.run_server
        serve(debug=self.debug)
if __name__ == '__main__':
    # Stand-alone demo: serve the prompting GUI on one LLFF "fern" frame with
    # SAM loaded directly. No 3D backend is attached (Seg3d=None), so the
    # training callbacks cannot run in this mode.
    from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                                  sam_model_registry)

    class Sam_predictor():
        """Load the SAM ViT-H checkpoint and wrap it in a SamPredictor."""

        def __init__(self, device):
            sam_checkpoint = "./dependencies/sam_ckpt/sam_vit_h_4b8939.pth"
            model_type = "vit_h"
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
            self.predictor = SamPredictor(self.sam)
            print('sam inited!')

        def forward(self, points, multimask_output=True, return_logits=False):
            """Predict masks from point prompts on the image given to ``set_image``.

            Every point gets label 1 (SAM's foreground label).
            """
            # TODO: add interactive mode
            input_point = points
            input_label = np.ones(len(input_point))
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=multimask_output,
                return_logits=return_logits
            )
            return masks

    image_path = 'data/nerf_llff_data(NVOS)/fern/images_4/image000.png'
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail here with a clear message rather than inside cvtColor.
        raise FileNotFoundError(f"cv2.imread could not read {image_path!r}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    sam_pred = Sam_predictor(torch.device('cuda'))
    sam_pred.predictor.set_image(image)
    video = np.stack(imageio.mimread('logs/llff/fern/render_train_coarse_segmentation_gui/video.rgbseg_gui.mp4'))
    gui = Sam3dGUI(None, debug=True)
    gui.ctx['cur_img'] = image
    gui.ctx['video'] = video
    gui.run_app(sam_pred.predictor, gui.ctx, image)