DrumJu commited on
Commit
182a35c
·
1 Parent(s): b171fe5

Add application file

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import tempfile
5
+ from PIL import Image
6
+ import time
7
+ temp_dir=tempfile.gettempdir()
8
+
9
+ def process_images(condition_images, input_images):
10
+ start = time.time()
11
+ output_images = []
12
+ for img in input_images:
13
+ if img is not None:
14
+ output_images.append(img)
15
+ pth_path = os.path.join(temp_dir, "output.pth")
16
+ temp_data = {"test": "test"}
17
+ torch.save(temp_data, pth_path)
18
+ end=time.time()
19
+ process_time = f"{end-start:.2f} s"
20
+ return output_images, pth_path, process_time
21
+
22
+ with gr.Blocks() as demo:
23
+ gr.Markdown("Title")
24
+ with gr.Row():
25
+ with gr.Group():
26
+ condition_inputs = gr.Files(label="Condition Img", file_types=[".png", ".jpg", ".jpeg"], type='filepath')
27
+ input_images = gr.Files(label="Input Img", file_types=[".png", ".jpg", ".jpeg"], type='filepath')
28
+ with gr.Group():
29
+ output_gallery = gr.Gallery(label="Output Img", show_label=True, columns=2)
30
+ pth_output = gr.File(label="下载 PTH 文件")
31
+ process_time = gr.Textbox(label="处理用时", type="text",interactive=False)
32
+ button1 = gr.Button("上传并处理")
33
+ def func1(condition_files, input_files):
34
+ condition_imgs = [Image.open(f) for f in condition_files] if condition_files else []
35
+ input_imgs = [Image.open(f) for f in input_files] if input_files else []
36
+ return process_images(condition_imgs, input_imgs)
37
+
38
+ button1.click(
39
+ fn=func1,
40
+ inputs=[condition_inputs, input_images],
41
+ outputs=[output_gallery, pth_output, process_time]
42
+ )
43
+
44
+ demo.launch()