Spaces:
Runtime error
Runtime error
| # Imports | |
| import os | |
| import copy | |
| import torch | |
| import glob | |
| import pandas as pd | |
| import pickle | |
| from xml.dom import minidom | |
| from svgpathtools import svg2paths2 | |
| from svgpathtools import wsvg | |
| import sys | |
| sys.path.append(os.getcwd()) | |
| from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG | |
| from src.preprocessing.deepsvg.deepsvg_config import config_hierarchical_ordered | |
| from src.preprocessing.deepsvg.deepsvg_utils import train_utils | |
| from src.preprocessing.deepsvg.deepsvg_utils import utils | |
| from src.preprocessing.deepsvg.deepsvg_dataloader import svg_dataset | |
| # ---- Methods for embedding logos ---- | |
def compute_embedding_folder(folder_path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Compute DeepSVG embeddings for every file in a folder.

    Each entry of ``folder_path`` (visited in sorted order) is passed to
    ``compute_embedding``. Entries that fail are reported and skipped, so a
    single bad logo does not abort the whole run.

    NOTE(review): ``compute_embedding`` rewrites each input file in place.

    Args:
        folder_path: Directory containing the SVG logos.
        model_path: Path to the pretrained DeepSVG checkpoint.
        save: Optional output directory. When given, the concatenated result
            is pickled to ``<save>/svg_embedding_5000.pkl``.

    Returns:
        The concatenation of the per-file results of ``compute_embedding``,
        each extended with a ``filename`` column naming the source file.
        The result is an empty DataFrame when no file could be embedded.
    """
    data_list = []
    for file in sorted(os.listdir(folder_path)):
        print('File: ' + file)
        try:
            embedding = compute_embedding(os.path.join(folder_path, file), model_path)
        except Exception as err:  # best-effort: skip this file but keep going
            print('Embedding failed')
            print(f'  {type(err).__name__}: {err}')
            continue
        embedding['filename'] = file
        data_list.append(embedding)
    print('Concatenating')
    # pd.concat raises ValueError on an empty list; return an empty frame instead.
    data = pd.concat(data_list) if data_list else pd.DataFrame()
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    return data
def compute_embedding(path: str, model_path: str, save: str = None) -> pd.DataFrame:
    """Compute a DeepSVG embedding for every <path> element of one SVG file.

    Pipeline:
      1. Rewrite ``path`` in place. Primitives are converted to paths
         (svgpathtools), the result is normalized with DeepSVG, and every
         <path> is tagged with an ``animation_id`` attribute ("0", "1", ...).
      2. Split the SVG into one single-path SVG per <path>.
      3. Canonicalize each piece and write it to
         ``data/temp_svg/path_<animation_id>.svg``. Per-path metadata is
         written to ``data/metadata/metadata.csv``.
      4. Encode every piece with the pretrained DeepSVG model.

    Args:
        path: SVG file to embed. NOTE: this file is overwritten in place.
        model_path: Path to the pretrained DeepSVG checkpoint.
        save: Optional output directory. When given, the result is pickled
            to ``<save>/svg_embedding_5000.pkl``.

    Returns:
        A DataFrame with an ``animation_id`` column plus the latent vector
        expanded into columns (presumably one column per latent dimension),
        one row per path. Rows containing NaN are dropped. When the SVG has
        no <path> elements, only the ``animation_id`` column is present.
    """
    document = _normalize_and_tag_svg(path)
    decomposed_svgs = _decompose_paths(document)
    svg_files = _write_temp_svgs(decomposed_svgs)
    if svg_files:
        data = _embed_svg_files(svg_files, model_path)
    else:
        # Nothing to encode: skip loading the model and the dataset.
        data = pd.DataFrame(columns=['animation_id'])
    if save is not None:
        with open(os.path.join(save, 'svg_embedding_5000.pkl'), 'wb') as output:
            pickle.dump(data, output)
    print('Embedding computed')
    return data


def _normalize_and_tag_svg(path: str):
    """Normalize the SVG at ``path`` in place and tag each <path> with an animation_id.

    Returns the minidom document that was written back to ``path``.
    """
    # Convert all primitives to SVG paths - TODO text
    paths, attributes, svg_attributes = svg2paths2(path)  # In previous project, this is performed at the end
    wsvg(paths, attributes=attributes, svg_attributes=svg_attributes, filename=path)
    svg = SVG.load_svg(path)
    svg.normalize()  # Using DeepSVG normalize instead of expanding viewbox - TODO check is this equal?
    # Assign animation id to every path - TODO this changes the original logo!
    document = minidom.parseString(svg.to_str())
    for index, path_node in enumerate(document.getElementsByTagName('path')):
        path_node.setAttribute('animation_id', str(index))
    with open(path, 'wb') as svg_file:
        svg_file.write(document.toxml(encoding='iso-8859-1'))
    return document


def _decompose_paths(document) -> dict:
    """Build one SVG per <path> element, keeping only that path.

    Paths whose direct parent is a <clipPath> are never removed. A path
    without a ``style`` attribute gets a black 2px stroke so it stays visible.

    Returns a dict mapping each animation_id (str) to the serialized SVG
    (bytes, iso-8859-1 encoded).
    """
    decomposed_svgs = {}
    n_paths = len(document.getElementsByTagName('path'))
    for index in range(n_paths):
        doc_temp = copy.deepcopy(document)
        paths_temp = doc_temp.getElementsByTagName('path')
        current_path = paths_temp[index]
        # Keep path `index` and remove all others (except clipPath children).
        for other in paths_temp[:index] + paths_temp[index + 1:]:
            if other.parentNode.nodeName != 'clipPath':
                other.parentNode.removeChild(other)
        # minidom returns '' for a missing attribute.
        if not current_path.getAttribute('style'):
            current_path.setAttribute('stroke', 'black')
            current_path.setAttribute('stroke-width', '2')
        animation_id = current_path.getAttribute('animation_id')
        decomposed_svgs[animation_id] = doc_temp.toprettyxml(encoding='iso-8859-1')
        doc_temp.unlink()
    return decomposed_svgs


def _write_temp_svgs(decomposed_svgs: dict) -> list:
    """Canonicalize each single-path SVG and write it plus metadata to ``data/``.

    Returns a list of ``(animation_id, file_path)`` pairs, in the order of
    ``decomposed_svgs``.
    """
    os.makedirs('data/temp_svg', exist_ok=True)
    os.makedirs('data/metadata', exist_ok=True)
    # Remove pieces left over from a previous call so they cannot be mixed
    # into this SVG's dataset.
    for stale in glob.glob('data/temp_svg/path_*.svg'):
        os.remove(stale)

    svg_files = []
    records = []
    for animation_id, svg_d_str in decomposed_svgs.items():
        # Applies DeepSVG canonicalize; previously custom methods were used
        current_svg = SVG.from_str(svg_d_str)
        current_svg.canonicalize()
        file_path = f'data/temp_svg/path_{animation_id}.svg'
        with open(file_path, 'w') as svg_file:
            svg_file.write(current_svg.to_str())
        svg_files.append((animation_id, file_path))

        len_groups = [group.total_len() for group in current_svg.svg_path_groups]
        start_pos = [group.svg_paths[0].start_pos for group in current_svg.svg_path_groups]
        records.append({
            'id': animation_id,
            'total_len': sum(len_groups),
            'nb_groups': len(len_groups),
            'len_groups': len_groups,
            'max_len_group': max(len_groups) if len_groups else 0,
            'start_pos': start_pos,
        })
    pd.DataFrame(records).to_csv('data/metadata/metadata.csv', index=False)
    return svg_files


def _embed_svg_files(svg_files: list, model_path: str) -> pd.DataFrame:
    """Encode each ``(animation_id, file_path)`` pair with the pretrained DeepSVG model."""
    # Load pretrained DeepSVG model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cfg = config_hierarchical_ordered.Config()
    model = cfg.make_model().to(device)
    train_utils.load_model(model_path, model)
    model.eval()
    # Load dataset (used here for its preprocessing helpers)
    cfg.data_dir = 'data/temp_svg/'
    cfg.meta_filepath = 'data/metadata/metadata.csv'
    dataset = svg_dataset.load_dataset(cfg)

    svg_list = []
    for animation_id, svg_file in svg_files:
        svg = SVG.load_svg(svg_file)
        svg = dataset.simplify(svg)
        svg = dataset.preprocess(svg, augment=False)
        item = dataset.get(svg=svg)
        model_args = utils.batchify((item[key] for key in cfg.model_args), device)
        with torch.no_grad():
            z = model(*model_args, encode_mode=True).cpu().numpy()[0][0][0]
        svg_list.append({'animation_id': animation_id, 'embedding': z})

    # Expand the latent vector into one column per component.
    data = pd.DataFrame.from_records(svg_list, index='animation_id')['embedding'].apply(pd.Series)
    data.reset_index(level=0, inplace=True)
    data.dropna(inplace=True)
    data.reset_index(drop=True, inplace=True)
    return data
| #compute_embedding_folder('data/raw_dataset', 'src/preprocessing/deepsvg/deepsvg_models/deepSVG_hierarchical_ordered.pth.tar', 'data/embedding') |