Spaces:
Running
Running
| import numpy as np | |
| import gradio as gr | |
| import cv2 | |
| import os | |
| import argparse | |
| from inference import Predictor | |
| import io | |
| #from black import to_black | |
| # os.system("wget https://huggingface.co/YANGYYYY/cartoonize/tree/main/GeneratorV2_train_photo_Hayao_init.pt") | |
| # if os.path.exists("GeneratorV2_train_photo_Hayao_init.pt"): | |
| # print("下载成功!") | |
| # else: | |
| # print("下载失败!") | |
def parse_args(argv=None):
    """Parse the command-line options used by the image transfer tab.

    Args:
        argv: Argument list to parse; ``None`` (default) reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``weight`` (generator checkpoint path,
        defaulting to the Hayao checkpoint) and ``device`` ('cuda' or 'cpu').
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')
    # parse_known_args instead of parse_args: this runs inside request
    # handlers, and any flag this parser does not define (e.g. the video
    # parser's --batch-size) would otherwise make argparse call sys.exit().
    args, _unknown = parser.parse_known_args(argv)
    return args
def parse_args_video(argv=None):
    """Parse the command-line options used by the video transfer tab.

    Args:
        argv: Argument list to parse; ``None`` (default) reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``weight``, ``src``, ``out``, ``batch_size``,
        ``start`` and ``end`` (seconds; ``end`` defaults to 10).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--src', type=str, default='dataset/video/花.mp4', help='Path to input video')
    parser.add_argument('--out', type=str, default='dataset/video_Hayao/hua_hayao.mp4', help='Path to save new video')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    parser.add_argument('--end', type=int, default=10, help='End time of video (second), 0 if not set')
    # Tolerate flags this parser does not define (e.g. --device, which only
    # the image parser knows); strict parse_args would sys.exit() mid-request.
    args, _unknown = parser.parse_known_args(argv)
    return args
def transfer(image, transfer_style):
    """Stylize one image with the generator checkpoint for *transfer_style*.

    Args:
        image: Value of the ``gr.Image`` input; ``None`` when it is empty.
        transfer_style: Dropdown value: "Hayao", "Shinkai" or "Kon Satoshi".

    Returns:
        The result of ``Predictor.transform_image``. Returns ``None`` when no
        image was given, and the input unchanged for any other style
        (including an unset dropdown).
    """
    # Style -> checkpoint. ``None`` means "use --weight from the CLI", whose
    # default is the Hayao checkpoint. "Kon Satoshi" uses the Paprika weights.
    style_weights = {
        "Hayao": None,
        "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
        "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
    }
    if transfer_style not in style_weights:
        return image
    if image is None:
        # Transfer clicked with no image: clear the output instead of
        # passing None into the model.
        return None
    args = parse_args()
    weight = style_weights[transfer_style] or args.weight
    # NOTE(review): a new Predictor (and its weights) is built on every
    # request; consider caching per (weight, device) if loading is slow.
    predictor = Predictor(weight, args.device)
    return predictor.transform_image(image)
def transfer_video(video_input, transfer_style):
    """Stylize a video file and return the path of the rendered result.

    Args:
        video_input: Path from the ``gr.Video`` input; ``None`` when empty.
        transfer_style: Dropdown value: "Hayao", "Shinkai" or "Kon Satoshi".

    Returns:
        Path of the written video, or ``None`` (which clears the output
        component) when no video was given or the style is not recognised.

    The time window and batch size come from ``parse_args_video`` (by default
    seconds 0-10, batch size 4). The Predictor is built without an explicit
    device, so its own default applies.
    """
    import tempfile

    # Style -> checkpoint. ``None`` means "use --weight from the CLI", whose
    # default is the Hayao checkpoint. "Kon Satoshi" uses the Paprika weights.
    style_weights = {
        "Hayao": None,
        "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
        "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
    }
    if transfer_style not in style_weights or video_input is None:
        # Returning 0 here (as before) is not a valid gr.Video value; None clears it.
        return None
    args = parse_args_video()
    if style_weights[transfer_style] is not None:
        args.weight = style_weights[transfer_style]
    args.src = video_input
    # A fresh directory per request, so concurrent users no longer overwrite a
    # shared ./video.mp4. NOTE(review): these temp dirs are not cleaned up here.
    args.out = os.path.join(tempfile.mkdtemp(), "video.mp4")
    Predictor(args.weight).transform_video(args.src, args.out, args.batch_size, start=args.start, end=args.end)
    return args.out
def clear_output(input_widget):
    """Clear the Gradio output component this handler is wired to.

    Args:
        input_widget: Current value of the paired input component. It is
            ignored; it is accepted only because the click handler passes it.

    Returns:
        ``None``, which Gradio renders as an empty output component.
    """
    return None
# Style names offered by both tabs; they match the styles handled by
# transfer() and transfer_video().
_STYLE_CHOICES = ["Hayao", "Shinkai", "Kon Satoshi"]


def _build_transfer_tab(tab_title, media_component):
    """Lay out one tab (media in/out, style dropdown, buttons) and return its widgets.

    Must be called inside an active ``gr.Tabs()`` context so the components
    attach to the enclosing Blocks.
    """
    with gr.TabItem(tab_title):
        with gr.Row():
            source = media_component()
            result = media_component()
        with gr.Row():
            style = gr.Dropdown(label="Transfer Style", choices=list(_STYLE_CHOICES))
            run = gr.Button("Transfer")
            reset = gr.Button("Clear")
    return source, result, style, run, reset


with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")
    with gr.Tabs():
        (image_input, image_output, image_dropdown,
         image_button, clear_image_button) = _build_transfer_tab("Transfer Image", gr.Image)
        (video_input, video_output, video_dropdown,
         video_button, clear_video_button) = _build_transfer_tab("Transfer Video", gr.Video)

    # Event wiring must happen inside the Blocks context.
    image_button.click(transfer, inputs=[image_input, image_dropdown], outputs=image_output)
    video_button.click(transfer_video, inputs=[video_input, video_dropdown], outputs=video_output)
    # The Clear buttons reset only the output component of their tab.
    clear_image_button.click(clear_output, inputs=image_input, outputs=image_output)
    clear_video_button.click(clear_output, inputs=video_input, outputs=video_output)

demo.launch()
# Launch the interface.
# For a local-only server instead:
# demo.launch(server_name='127.0.0.1', server_port=7788)