wangbaoyuan commited on
Commit
479af2c
·
verified ·
1 Parent(s): 9d85c29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -102
app.py CHANGED
@@ -1,102 +1,102 @@
1
- import gradio as gr
2
- from options.test_options import TestOptions
3
- from data import create_dataset
4
- from models import create_model
5
- from PIL import Image
6
- import torchvision.transforms as transforms
7
- import torch
8
- import sys
9
- import matplotlib.pyplot as plt
10
- "python test.py --model test --name selfie2anime --dataroot selfie2anime/testB --num_test 100 --model_suffix '_B' --no_dropout"
11
-
12
-
13
- title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
14
- description = ""
15
- article = ""
16
-
17
- def reset_interface():
18
- return gr.update(value=None), gr.update(visible=False)
19
-
20
- def inference(img):
21
- try:
22
- # Debugging: Check if image is correctly received
23
- if img is None:
24
- print("No image received!")
25
- return None
26
- import sys
27
- sys.argv = ['--model', '--dataroot', '/home/data/luhaoxiang/wby/cyclegan/img/', '--num_test', '1', '--no_dropout']
28
-
29
- # Load options and set them up
30
- opt = TestOptions().parse()
31
- opt.num_threads = 0
32
- opt.batch_size = 1
33
- opt.serial_batches = True
34
- opt.no_flip = True
35
- opt.display_id = -1
36
- opt.name = 'selfie2anime'
37
- opt.model_suffix = '_B'
38
- opt.num_test = 1
39
- opt.no_dropout = True
40
-
41
- # Create model and set it up
42
- dataset = create_dataset(opt)
43
- model = create_model(opt)
44
- model.setup(opt)
45
- if opt.eval:
46
- model.eval()
47
- # Convert PIL image to tensor
48
- img_tensor = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0)
49
- img_tensor = img_tensor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
50
-
51
- # Prepare data for the model
52
- data = {'A':img_tensor,'A_paths':'/home/data/luhaoxiang/wby/cyclegan/img/'}
53
- model.set_input(data)
54
- model.test()
55
-
56
- # Get the output visuals
57
- img_out = model.get_current_visuals()
58
- output_img_tensor = img_out.get('fake')
59
- print(f'type of output_img_tensor: {type(img_out)}')
60
- if output_img_tensor is None:
61
- print("No output from model!")
62
- return None
63
-
64
- if isinstance(output_img_tensor, torch.Tensor):
65
- # 将张量转换回PIL图像
66
- output_img = output_img_tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
67
- output_img = (output_img * 0.5 + 0.5) * 255 # 假设输出在[-1, 1]之间标准化
68
- output_img = output_img.astype('uint8')
69
- output_img = Image.fromarray(output_img)
70
- print(f'type if output_img_tensor: {type(output_img_tensor)}')
71
- return output_img
72
- else:
73
- print(f"意外的输出类型: {type(output_img_tensor)}")
74
- return None
75
-
76
- except Exception as e:
77
- print(f"Error during inference: {e}")
78
- return None
79
-
80
- example_images = [
81
- "img/1.png"
82
- ]
83
-
84
- with gr.Blocks() as demo:
85
- gr.Markdown(f"### {title}")
86
- gr.Markdown(description)
87
-
88
- with gr.Row():
89
- with gr.Column():
90
- img_input = gr.Image(type="pil", label="Upload an Image")
91
- submit_btn = gr.Button("Submit")
92
- with gr.Column():
93
- output = gr.Image(type="pil", label="Prediction Result")
94
-
95
- submit_btn.click(fn=inference, inputs=img_input, outputs=output)
96
- demo.load(reset_interface, None, output)
97
- gr.Examples(
98
- examples=example_images,
99
- inputs=img_input,
100
- )
101
-
102
- demo.launch()
 
1
+ import gradio as gr
2
+ from options.test_options import TestOptions
3
+ from data import create_dataset
4
+ from models import create_model
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import torch
8
+ import sys
9
+ import matplotlib.pyplot as plt
10
+ "python test.py --model test --name selfie2anime --dataroot selfie2anime/testB --num_test 100 --model_suffix '_B' --no_dropout"
11
+
12
+
13
+ title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
14
+ description = ""
15
+ article = ""
16
+
17
+ def reset_interface():
18
+ return gr.update(value=None), gr.update(visible=False)
19
+
20
+ def inference(img):
21
+ try:
22
+ # Debugging: Check if image is correctly received
23
+ if img is None:
24
+ print("No image received!")
25
+ return None
26
+ import sys
27
+ sys.argv = ['--model', '--dataroot', '/home/data/luhaoxiang/wby/cyclegan/img/', '--num_test', '1', '--no_dropout']
28
+
29
+ # Load options and set them up
30
+ opt = TestOptions().parse()
31
+ opt.num_threads = 0
32
+ opt.batch_size = 1
33
+ opt.serial_batches = True
34
+ opt.no_flip = True
35
+ opt.display_id = -1
36
+ opt.name = 'selfie2anime'
37
+ opt.model_suffix = '_B'
38
+ opt.num_test = 1
39
+ opt.no_dropout = True
40
+
41
+ # Create model and set it up
42
+ dataset = create_dataset(opt)
43
+ model = create_model(opt)
44
+ model.setup(opt)
45
+ if opt.eval:
46
+ model.eval()
47
+ # Convert PIL image to tensor
48
+ img_tensor = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0)
49
+ img_tensor = img_tensor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # Move to GPU if available
50
+
51
+ # Prepare data for the model
52
+ data = {'A':img_tensor,'A_paths':'/home/data/luhaoxiang/wby/cyclegan/img/'}
53
+ model.set_input(data)
54
+ model.test()
55
+
56
+ # Get the output visuals
57
+ img_out = model.get_current_visuals()
58
+ output_img_tensor = img_out.get('fake')
59
+ print(f'type of output_img_tensor: {type(img_out)}')
60
+ if output_img_tensor is None:
61
+ print("No output from model!")
62
+ return None
63
+
64
+ if isinstance(output_img_tensor, torch.Tensor):
65
+ # 将张量转换回PIL图像
66
+ output_img = output_img_tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
67
+ output_img = (output_img * 0.5 + 0.5) * 255 # 假设输出在[-1, 1]之间标准化
68
+ output_img = output_img.astype('uint8')
69
+ output_img = Image.fromarray(output_img)
70
+ print(f'type if output_img_tensor: {type(output_img_tensor)}')
71
+ return output_img
72
+ else:
73
+ print(f"意外的输出类型: {type(output_img_tensor)}")
74
+ return None
75
+
76
+ except Exception as e:
77
+ print(f"Error during inference: {e}")
78
+ return None
79
+
80
+ example_images = [
81
+ "img/1.png"
82
+ ]
83
+
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown(f"### {title}")
86
+ gr.Markdown(description)
87
+
88
+ with gr.Row():
89
+ with gr.Column():
90
+ img_input = gr.Image(type="pil", label="Upload an Image")
91
+ submit_btn = gr.Button("Submit...")
92
+ with gr.Column():
93
+ output = gr.Image(type="pil", label="Prediction Result")
94
+
95
+ submit_btn.click(fn=inference, inputs=img_input, outputs=output)
96
+ demo.load(reset_interface, None, output)
97
+ gr.Examples(
98
+ examples=example_images,
99
+ inputs=img_input,
100
+ )
101
+
102
+ demo.launch()