Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from options.test_options import TestOptions | |
| from data import create_dataset | |
| from models import create_model | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import torch | |
| import sys | |
| import matplotlib.pyplot as plt | |
# Reference command-line invocation of the test script; inference() sets a
# similar subset of these options (dataroot, num_test, model_suffix,
# no_dropout) programmatically:
#   python test.py --model test --name selfie2anime --dataroot selfie2anime/testB --num_test 100 --model_suffix '_B' --no_dropout
title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
description = ""
# NOTE(review): `article` is not rendered by the UI code in this file.
article = ""
def reset_interface():
    """Clear the prediction output when the page loads.

    The page-load event is wired to a single output component, and Gradio
    expects one return value per output component. Exactly one update is
    therefore returned. A tuple of two updates would be handed to the image
    component as its value and fail postprocessing.

    Returns:
        A Gradio update that sets the output component's value to None.
    """
    return gr.update(value=None)
def resize_image(img, size=(256, 256)):
    """Resize a PIL image using bicubic resampling.

    Args:
        img: PIL image to resize.
        size: Target (width, height). The default, 256x256, is the size that
            inference() uses for images that fail the resolution check.

    Returns:
        The resized image.
    """
    return img.resize(size, Image.BICUBIC)
def check_resolution(img, allowed_sizes=((256, 256), (64, 64))):
    """Return True if the image's size is one of the accepted resolutions.

    Args:
        img: Object whose ``size`` attribute holds (width, height), such as a
            PIL image.
        allowed_sizes: Accepted (width, height) pairs. Defaults to 256x256 and
            64x64.

    Returns:
        True when ``img.size`` matches an entry of ``allowed_sizes``,
        otherwise False.
    """
    width, height = img.size
    return (width, height) in allowed_sizes
def _parse_test_options():
    """Parse the test options and apply the single-image inference overrides.

    ``TestOptions().parse()`` is called without arguments, so it presumably
    reads ``sys.argv`` through argparse (TODO confirm against
    ``options/test_options.py``). ``sys.argv`` is therefore replaced only for
    the duration of the call and then restored, so the rest of the process
    does not see the fake command line.

    Returns:
        The parsed option object with the overrides below applied.
    """
    saved_argv = sys.argv
    # argparse skips argv[0] (the program name). Keep a real program name in
    # that slot so it does not swallow a flag; the remaining flags are parsed.
    program = saved_argv[0] if saved_argv else "app.py"
    sys.argv = [program, '--dataroot', './data/', '--num_test', '1', '--no_dropout']
    try:
        opt = TestOptions().parse()
    finally:
        sys.argv = saved_argv

    # Single-image, single-process, deterministic test settings.
    opt.num_threads = 0
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.display_id = -1
    # NOTE(review): empty experiment name; confirm which checkpoint directory
    # this makes the model load weights from.
    opt.name = ''
    opt.model_suffix = '_B'
    opt.num_test = 1
    opt.no_dropout = True
    return opt


def _tensor_to_pil(tensor):
    """Convert a generator output tensor to a PIL image.

    Assumes the values are normalized to [-1, 1] and the layout is
    (1, C, H, W) or (C, H, W). TODO: confirm against the generator's final
    activation. Values are clipped to [0, 255] before the uint8 cast so
    out-of-range values saturate instead of wrapping around. A
    single-channel result becomes a grayscale image.
    """
    array = tensor.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
    # Map [-1, 1] -> [0, 255].
    array = ((array * 0.5 + 0.5) * 255).clip(0, 255).astype('uint8')
    if array.shape[2] == 1:
        # Image.fromarray does not accept (H, W, 1) arrays; drop the channel
        # axis to get an (H, W) grayscale image.
        array = array[:, :, 0]
    return Image.fromarray(array)


def inference(img):
    """Run the model on one uploaded image and return the generated image.

    Args:
        img: PIL image from the Gradio input, or None when nothing was
            uploaded.

    Returns:
        The generated PIL image. Returns None in any of these cases: there
        was no input, the model produced no 'fake' visual, the output was not
        a tensor, or an exception was raised. Exceptions are printed together
        with their traceback.
    """
    import traceback

    try:
        if img is None:
            print("No image received!")
            return None
        # Images that are neither 256x256 nor 64x64 are resized to 256x256.
        if not check_resolution(img):
            img = resize_image(img)

        opt = _parse_test_options()
        # NOTE(review): options, dataset and weights are rebuilt on every
        # request. Caching them would cut latency, but the model would then be
        # shared across concurrent requests.
        # The dataset object itself is not used below; the call is kept for
        # whatever setup or validation create_dataset performs.
        create_dataset(opt)
        model = create_model(opt)
        model.setup(opt)
        if opt.eval:
            model.eval()

        # ToTensor maps pixels to [0, 1], and Normalize(0.5, 0.5) shifts them
        # to [-1, 1], the same range the output is decoded from in
        # _tensor_to_pil. This assumes the generator was trained on [-1, 1]
        # inputs; confirm against the training transform.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        img_tensor = preprocess(img.convert('RGB')).unsqueeze(0).to(device)

        # Named `batch` so it does not shadow the imported `data` package.
        batch = {'A': img_tensor, 'A_paths': './data/'}
        model.set_input(batch)
        model.test()

        visuals = model.get_current_visuals()
        output_img_tensor = visuals.get('fake')
        print(f'type of output_img_tensor: {type(output_img_tensor)}')
        if output_img_tensor is None:
            print("No output from model!")
            return None
        if not isinstance(output_img_tensor, torch.Tensor):
            print(f"意外的输出类型: {type(output_img_tensor)}")
            return None
        return _tensor_to_pil(output_img_tensor)
    except Exception as e:
        # Top-level boundary for the UI callback: report the failure with its
        # traceback and return an empty result instead of raising into Gradio.
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return None
# Example inputs offered by the gr.Examples widget. The paths are relative, so
# they resolve against the working directory.
example_images = ["img/1.png"]
# Page layout:
#   - title and description,
#   - an input column with the upload widget and submit button,
#   - an output column with the prediction image,
#   - the example images.
with gr.Blocks() as demo:
    gr.Markdown(f"### {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            submit_btn = gr.Button("Submit...")
        with gr.Column():
            output = gr.Image(type="pil", label="Prediction Result")

    # Events: run inference when Submit is clicked; reset the output on page load.
    submit_btn.click(fn=inference, inputs=img_input, outputs=output)
    demo.load(fn=reset_interface, inputs=None, outputs=output)

    gr.Examples(examples=example_images, inputs=img_input)

demo.launch()