Spaces:
Running
Running
| import os | |
| import uuid | |
| import argparse | |
# Command-line options. This runs before `import torch` further down so that
# CUDA_VISIBLE_DEVICES is already set when torch is first imported.
argparser = argparse.ArgumentParser()
# NOTE(review): args.port is not passed to block.launch() in this file, so
# this option currently has no effect.
argparser.add_argument("--port", type=int, default=1239, help="Port number for the local server")
# Previously this flag was parsed but ignored (the env var was always forced
# to ''). It is now honoured; the default '' keeps the old behaviour of
# hiding every GPU from torch.
argparser.add_argument(
    "--cuda_device",
    type=str,
    default='',
    help="Value for CUDA_VISIBLE_DEVICES. Default '' hides all GPUs",
)
argparser.add_argument("--static_folder", type=str, default='static', help="Folder to store static files")
args = argparser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
| import gradio as gr | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from LInK.CAD import create_3d_html | |
| import numpy as np | |
| from LInK.CurveUtils import uniformize | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from LInK.Solver import solve_rev_vectorized_batch_CPU | |
# Everything in this app is inference; turn off autograd globally.
torch.set_grad_enabled(False)

# Precomputed per-letter data, loaded relative to the working directory.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; these files
# must come from a trusted source.
results = np.load('alpha_res.npy', allow_pickle=True)
alphabet_test = np.load('alphabet.npy')
zs_ = np.load('alpha_z.npy', allow_pickle=True)


def _center_curves(curves):
    """Subtract each curve's mean point (mean over axis 1) from all its points."""
    return curves - curves.mean(1).unsqueeze(1)


def _anchor_angles(curves):
    """Return, per curve, the atan2 angle of its point with the largest squared norm.

    ``curves`` is indexed as ``curves[curve, point, xy]``; the result is a
    NumPy array with one angle per curve.
    """
    far = torch.square(curves).sum(-1).argmax(dim=1)
    rows = torch.arange(curves.shape[0])
    return torch.atan2(curves[rows, far, 1], curves[rows, far, 0]).numpy()


# presumably the reference letter outlines (alphabet.npy) — confirm
curves__ = _center_curves(torch.tensor(alphabet_test).float())
theta = _anchor_angles(curves__)

# Curves stored as the last entry of every result, resampled to 200 points.
curves_ = torch.tensor(np.stack([result[-1] for result in results])).float()
curves_ = _center_curves(uniformize(curves_, 200))
theta2 = _anchor_angles(curves_)

# Hand-tuned extra rotations (radians) on top of the automatic alignment.
# Keys are row indices into `results`; via alphabet_dict these are
# E (4), H (7) and V (21).
_ALPHA_CORRECTIONS = {21: -np.pi / 2.5, 7: -np.pi / 3, 4: np.pi / 36}

alphas = []
letter_heights = []
letter_centers = []
letter_widths = []
for i, result in enumerate(results):
    # Rotating by theta - theta2 maps the traced curve's anchor angle onto the
    # reference letter's anchor angle.
    alpha = theta[i] - theta2[i]
    # Only touch corrected rows so uncorrected alphas keep their original
    # scalar type (and hence the precision np.cos/np.sin run at).
    if i in _ALPHA_CORRECTIONS:
        alpha += _ALPHA_CORRECTIONS[i]
    alphas.append(alpha)
    R = np.array([[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]]).squeeze()
    transformed_curve = (R @ result[-1].T).T
    # Bounding box of the rotated curve, used later to scale and place letters.
    # (The former per-letter 2000-step solver run here was discarded unused.)
    xs = transformed_curve[:, 0]
    ys = transformed_curve[:, 1]
    letter_heights.append(ys.max() - ys.min())
    letter_widths.append(xs.max() - xs.min())
    letter_centers.append([(xs.max() + xs.min()) / 2, (ys.max() + ys.min()) / 2])
alphas = np.array(alphas)
letter_heights = np.array(letter_heights)
letter_centers = np.array(letter_centers)
letter_widths = np.array(letter_widths)
# Upper-case letter -> row index into the per-letter tables (results, alphas,
# letter_heights, letter_widths, letter_centers, zs_).
# NOTE(review): P and Q map to 16 and 15, i.e. swapped relative to
# alphabetical order — presumably matching the row order of alpha_res.npy;
# confirm before "fixing".
alphabet_dict = {
    letter: "ABCDEFGHIJKLMNOQPRSTUVWXYZ".index(letter)
    for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
}
def _text_to_letter_indices(target_text):
    """Return the ``alphabet_dict`` row index of every letter in ``target_text``.

    All whitespace (not just spaces) is removed and the text is upper-cased.

    Raises:
        gr.Error: if nothing remains after removing whitespace, or if any
            character has no entry in ``alphabet_dict``.
    """
    cleaned = ''.join((target_text or '').split()).upper()
    if not cleaned:
        raise gr.Error("Please enter at least one letter (A-Z).")
    unsupported = sorted({ch for ch in cleaned if ch not in alphabet_dict})
    if unsupported:
        raise gr.Error(
            f"Unsupported character(s): {' '.join(unsupported)}. Only the letters A-Z are supported."
        )
    return [alphabet_dict[ch] for ch in cleaned]


def _place_letters(letters, target_height=1., spacing=0.2):
    """Rotate, scale and translate each letter's stored mechanism onto one row.

    Each letter is rotated by ``alphas[l]`` and scaled so that its traced curve
    is ``target_height`` tall. It is then re-centred on its bounding-box centre
    and shifted right so neighbouring letters are ``spacing`` apart.

    Returns:
        A list of ``[A, x0, node_type, start_theta, end_theta]``, one per
        letter. ``x0`` is in the shared word frame and the theta range is
        offset by the letter's rotation.
    """
    mechs = []
    scaling = []
    translations = []
    for i, l in enumerate(letters):
        A, x0, node_type, start_theta, end_theta, _ = results[l][0]
        alpha = alphas[l]
        R = np.array([[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]]).squeeze()
        s = target_height / letter_heights[l]
        scaling.append(s)
        if i > 0:
            # previous centre + previous half-width + this half-width + gap
            trans = [translations[-1][0] + letter_widths[letters[i - 1]] / 2 * scaling[i - 1] + letter_widths[l] / 2 * s + spacing, 0]
        else:
            trans = [letter_widths[l] / 2 * s, 0]
        translations.append(trans)
        mechs.append([A, s * ((R @ x0.T).T - letter_centers[l]) + trans, node_type, start_theta + alpha, end_theta + alpha])
    return mechs


def _merge_and_simulate(mechs, letters, n_steps=100):
    """Merge per-letter mechanisms into one block-diagonal system and simulate each.

    Returns:
        ``(A_all, x0_all, node_type_all, zs, sols, highlights)``, where:
        - ``x0_all`` holds each letter's first simulated frame.
        - ``zs`` stacks each letter's z values above the previous letter's.
        - ``sols`` concatenates all trajectories along the node axis.
        - ``highlights`` is the merged index of every letter's last node.
    """
    total_size = sum(A.shape[0] for A, *_ in mechs)
    A_all = np.zeros((total_size, total_size))
    x0_all = np.zeros((total_size, 2))
    node_type_all = np.zeros((total_size, 1))
    sols, highlights, zs = [], [], []
    offset = 0
    for i, (A, x0, node_type, start_theta, end_theta) in enumerate(mechs):
        n = A.shape[0]
        span = slice(offset, offset + n)
        # Letters are independent linkages: their adjacency blocks never overlap.
        A_all[span, span] = A
        node_type_all[span] = node_type
        highlights.append(offset + n - 1)
        sol = solve_rev_vectorized_batch_CPU(A[None], x0[None], node_type[None], np.linspace(start_theta, end_theta, n_steps))[0]
        # judging from the indexing, sol is (nodes, steps, 2); store (steps, nodes, 2)
        sols.append(sol.transpose(1, 0, 2))
        x0_all[span] = sol[:, 0, :]
        z = zs_[letters[i]]
        zs.append(z + zs[-1].max() + 1 if i > 0 else z)
        offset += n
    return A_all, x0_all, node_type_all, np.concatenate(zs), np.concatenate(sols, axis=1), highlights


def create_mech(target_text):
    """Build one linkage mechanism that draws ``target_text`` and render it as 3D HTML.

    The per-letter mechanisms are placed on a row, merged and simulated. The
    animation is written to ``<static_folder>/<uuid>.html`` using the
    ``./static/animation.html`` template.

    Args:
        target_text: Word to draw. Whitespace is ignored and case does not
            matter; only the letters A-Z are supported.

    Returns:
        A ``gr.HTML`` component with an iframe pointing at the generated file.

    Raises:
        gr.Error: on empty input or unsupported characters.
    """
    letters = _text_to_letter_indices(target_text)
    mechs = _place_letters(letters)
    A_all, x0_all, node_type_all, zs, sols, highlights = _merge_and_simulate(mechs, letters)
    # Forward motion followed by the same motion reversed, so it plays back and forth.
    frames = sols.transpose(1, 0, 2)
    frames = np.concatenate([frames, frames[:, ::-1, :]], 1)
    uuid_ = str(uuid.uuid4())
    # Write into the folder registered with gr.set_static_paths (--static_folder);
    # the template itself is read from the repository's ./static folder.
    static_dir = args.static_folder
    create_3d_html(A_all, x0_all, node_type_all, zs, frames, template_path=f'./static/animation.html', save_path=f'./{static_dir}/{uuid_}.html', highlights=highlights)
    return gr.HTML(f'<iframe width="100%" height="800px" src="file={static_dir}/{uuid_}.html"></iframe>', label="3D Plot", elem_classes="plot3d")
# Let Gradio serve files from the static folder (placeholder page and the
# animations written by create_mech).
gr.set_static_paths(paths=[Path(f'./{args.static_folder}')])

with gr.Blocks() as block:
    with gr.Row():
        # Input controls.
        with gr.Column():
            text = gr.Textbox(
                label="Enter a word (spaces will be ignored)",
                value='DECODE',
            )
            btn = gr.Button(value="Create Mechanism", variant="primary")
        # Placeholder animation; replaced by create_mech's returned component.
        plot_3d = gr.HTML(
            '<iframe width="100%" height="800px" src="file=static/filler.html"></iframe>',
            label="3D Plot",
            elem_classes="plot3d",
        )
    event1 = btn.click(create_mech, inputs=[text], outputs=[plot_3d])

# NOTE(review): args.port is parsed but not passed here, so the server port
# is left to Gradio's own defaults.
block.launch()