Spaces:
Configuration error
Configuration error
| import gradio as gr | |
| import os | |
| import sys | |
| base_path = os.path.expanduser('~') | |
| sys.path.append(os.path.join(base_path, 'Er0mangaSeg/')) | |
| sys.path.append(os.path.join(base_path, 'Er0mangaSeg/demo')) | |
| from image_demo_tta import init_seg_model, inference_tta | |
| sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/')) | |
| sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/bin')) | |
| from uncen import init_inpaint_model, inpaint | |
| import time | |
| import numpy as np | |
| import cv2 | |
| import shutil | |
| import torch | |
| if torch.cuda.is_available(): | |
| print('GPU found!') | |
| device = 'cuda:0' | |
| else: | |
| print('GPU not found! Using CPU') | |
| device = 'cpu' | |
| config = os.path.join(base_path, 'Er0mangaSeg/configs/convnext/convnext_h.py') | |
| checkpoint = os.path.join(base_path, 'Er0mangaSeg/pretrained/convnext_1024_iter_400.pth') | |
| model_seg = init_seg_model(config, checkpoint, device=device) | |
| print('Segmentation initialized') | |
| inp_model_path = os.path.join(base_path, 'Er0mangaInpaint/pretrained/00-30-09') | |
| model_inp = init_inpaint_model(inp_model_path) | |
| print('Inpainting initialized') | |
| def proc(input_img): | |
| try: | |
| s = time.time() | |
| out_mask, raw_mask = inference_tta(model_seg, input_img) | |
| out_mask = np.dstack([out_mask, out_mask, out_mask]) | |
| raw_mask = np.dstack([raw_mask, raw_mask, raw_mask]) | |
| output_img, out_dbg = inpaint(model_inp, input_img, out_mask) | |
| e = time.time() | |
| print(f"proc_time: {e-s:.2f}") | |
| return output_img#, raw_mask | |
| except Exception as e: | |
| raise gr.Error(e) | |
| def proc_batch(batch): | |
| res = [] | |
| try: | |
| s = time.time() | |
| out_p = os.path.dirname(batch[0][0]) | |
| salt = str(np.random.randint(1e10)) | |
| out_p_d = os.path.join(out_p, '__salt_img__'+salt) | |
| out_p_m = os.path.join(out_p, '__salt_mask__'+salt) | |
| os.mkdir(out_p_d) | |
| os.mkdir(out_p_m) | |
| for i in range(len(batch)): | |
| input_path = batch[i][0] | |
| inp_name = os.path.basename(input_path) | |
| input_img = cv2.cvtColor(cv2.imread(input_path), cv2.COLOR_BGR2RGB) | |
| out_mask, raw_mask = inference_tta(model_seg, input_img) | |
| out_mask = np.dstack([out_mask, out_mask, out_mask]) | |
| raw_mask = np.dstack([raw_mask, raw_mask, raw_mask]) | |
| output_img, out_dbg = inpaint(model_inp, input_img, out_mask) | |
| out_path_img = os.path.join(out_p_d, inp_name) | |
| out_path_mask = os.path.join(out_p_m, inp_name+'.png') | |
| cv2.imwrite(out_path_img, cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)) | |
| cv2.imwrite(out_path_mask, raw_mask) | |
| res.append(out_path_img) | |
| ar_path = os.path.join(out_p, 'output') | |
| shutil.make_archive(ar_path, 'zip', out_p_d) | |
| ar_path_m = os.path.join(out_p, 'output_mask') | |
| shutil.make_archive(ar_path_m, 'zip', out_p_m) | |
| e = time.time() | |
| print(f"batch proc_time: {e-s:.2f}") | |
| return res, ar_path + '.zip', ar_path_m + '.zip' | |
| except Exception as e: | |
| raise gr.Error(e) | |
| demo1 = gr.Interface(proc, gr.Image(), gr.Image(format='png'), delete_cache=(7200, 7200), allow_flagging='never') | |
| demo2 = gr.Interface(proc_batch, gr.Gallery(), [gr.Gallery(value='str', format='png'), gr.File(), gr.File()], delete_cache=(7200, 7200), allow_flagging='never') | |
| demo = gr.TabbedInterface([demo1, demo2], ["Single image processing", "Batch processing (experimental)"]) | |
| if __name__ == "__main__": | |
| demo.launch(server_name='0.0.0.0', server_port=7860) | |