File size: 3,764 Bytes
fb59cb8
 
 
 
 
 
aa7c58e
7c078a3
c43f64e
01b28b7
7d3b3b8
 
c43f64e
8cdd8e4
de8aa71
 
 
 
 
 
7d3b3b8
aa7c58e
 
 
 
da72583
aa7c58e
 
7d3b3b8
 
c410097
 
7d3b3b8
aa7c58e
8cdd8e4
da72583
aa7c58e
 
 
 
 
7d3b3b8
 
da72583
7d3b3b8
a9d8072
c43f64e
 
aa7c58e
a9d8072
c43f64e
aa7c58e
 
 
 
7d3b3b8
aa7c58e
58653a8
aa7c58e
7d3b3b8
 
c410097
7d3b3b8
 
a83750e
bbe49e5
7d3b3b8
c43f64e
de8aa71
c43f64e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe49e5
c43f64e
aa7c58e
 
7d3b3b8
aa7c58e
de8aa71
 
 
 
 
 
 
aa7c58e
8cdd8e4
7d3b3b8
c43f64e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr
import PIL.Image
import zipfile
from genTag import genTag
from checkIgnore import is_ignore, ignore2

def predict(image: PIL.Image.Image):
    """Tag one uploaded image for the interactive "Single" tab.

    Returns a 3-tuple wired to ``[result_hide, result_hide2, result_loading]``:
    the raw result of ``genTag`` at a 0.5 threshold, the ``ignore2`` value
    imported from ``checkIgnore`` (both later handed to the browser-side
    ``window.m5Func.refresh``), and an empty ``<div>`` written into the
    loading placeholder.
    """
    scores = genTag(image, 0.5)
    placeholder_html = """<div></div>"""
    return scores, ignore2, placeholder_html

def predict_api(image: PIL.Image.Image):
    """Return the tags for *image* as a single comma-separated string.

    Tags come from ``genTag`` at a 0.5 threshold. Any tag for which
    ``is_ignore(tag, 2)`` is true is dropped. The remaining tags keep the
    order in which ``genTag`` produced them.
    """
    kept_tags = [
        tag_name
        for tag_name, _score in genTag(image, 0.5).items()
        if not is_ignore(tag_name, 2)
    ]
    # str.join already yields a str, so no extra conversion is needed.
    return ', '.join(kept_tags)

def predict_batch(zip_file, progress=gr.Progress()):
    """Tag every image in an uploaded ZIP archive and return a text report.

    Members whose names end in .png/.jpg/.jpeg/.webp (any letter case) are
    decoded and converted to RGBA. Each one is tagged with ``genTag`` at a
    0.5 threshold, and tags for which ``is_ignore(tag, 2)`` is true are
    dropped. Other members are skipped: directories, non-image files, and
    macOS ``__MACOSX/`` metadata.

    Args:
        zip_file: Value of the ``batch_file`` input, passed to
            ``zipfile.ZipFile`` (a path or a binary file object). ``None``
            (no file uploaded) yields an empty report.
        progress: Gradio progress tracker; supplied by Gradio at call time.

    Returns:
        For each tagged image, in archive order: the member name, a newline,
        the comma-separated tags, and a blank line. Empty string if none.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")
    if zip_file is None:
        return ''
    report = []
    with zipfile.ZipFile(zip_file) as zf:
        for name in progress.tqdm(zf.namelist()):
            print(name)
            # Finder-made archives add AppleDouble "._*.png"-style entries
            # under __MACOSX/. They are not images, and PIL would fail to
            # open them, which would abort the whole batch.
            if name.startswith("__MACOSX/"):
                continue
            # Compare case-insensitively so names like "IMG_0001.JPG" are
            # not silently skipped.
            if not name.lower().endswith(image_suffixes):
                continue
            # convert() returns a fully loaded copy, so the archive member
            # and the source image can both be closed immediately.
            with zf.open(name) as member, PIL.Image.open(member) as source:
                rgba = source.convert("RGBA")
            result_threshold = genTag(rgba, 0.5)
            kept_tags = [
                key
                for key, _value in result_threshold.items()
                if not is_ignore(key, 2)
            ]
            tag_text = ', '.join(kept_tags)
            report.append(f"{name}\n{tag_text}\n\n")
    return ''.join(report)

# Build the two-tab UI. "head.html" is injected into the page <head>.
# NOTE(review): it presumably defines the `window.m5Func` helpers that the JS
# callbacks below call — confirm.
with gr.Blocks(head_paths="head.html") as demo:
    # --- Single tab: upload one image; results are rendered by browser-side JS. ---
    with gr.Tab(label='Single'):
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(label='Upload a image',
                                 type='pil',
                                 elem_classes='m5dd_image',
                                 image_mode="RGBA",
                                 show_fullscreen_button=False,
                                 sources=["upload", "clipboard"])
                # Not referenced by any Python event below.
                # NOTE(review): the "m5dd_result" div is presumably filled by the JS in head.html.
                result_text = gr.HTML(value="""<div id="m5dd_result"></div>""", padding=False)
                # Hidden JSON carriers. predict() writes the genTag result and the
                # ignore list here, and the .success() JS step reads them back.
                result_hide = gr.JSON(visible=False)
                result_hide2 = gr.JSON(visible=False)
            with gr.Column(scale=2):
                # Not referenced by any Python event below.
                # NOTE(review): the "m5dd_list" div is presumably filled by the JS in head.html.
                result_html = gr.HTML(value="""<div id="m5dd_list"></div>""", padding=False)
                # Output slot for predict()'s third return value (an empty <div>).
                result_loading = gr.HTML(value="""<div></div>""", elem_classes='m5dd_html', padding=False)
    # --- Batch tab: upload a ZIP and get a plain-text tag report. ---
    with gr.Tab(label='Batch'):
        with gr.Row():
            with gr.Column(scale=1):
                batch_file = gr.File(label="Upload a ZIP file containing images",
                                     file_types=['.zip'])
                run_button2 = gr.Button('Run')
                # Invisible button. It appears to exist only so that predict_api
                # is exposed as the named API endpoint "predict" (see below).
                run_button_api = gr.Button(value='Run', visible=False)
            with gr.Column(scale=2):
                result_text2 = gr.Textbox(lines=20,
                                          max_lines=20,
                                          label='Result',
                                          show_copy_button=True,
                                          autoscroll=False)
    # On upload, the `js` hook runs in the browser before `fn`: it calls
    # window.m5Func.clear() and passes the image through unchanged. predict()
    # then runs server-side and fills the hidden JSON fields. api_name=False
    # keeps this event off the public API.
    image.upload(
        fn=predict,
        inputs=[image],
        outputs=[result_hide, result_hide2, result_loading],
        api_name=False,
        js="""
        (image) => {
            window.m5Func.clear()
            return image;
        }
        """,
    ).success(
        # Client-only step (fn=None), run only if predict() succeeded. It hands
        # the tag result and ignore list to window.m5Func.refresh in the browser.
        fn=None,
        inputs=[result_hide, result_hide2],
        js="""
        (result, ignore) => {
            window.m5Func.refresh(result, ignore)
            return [result, ignore];
        }
        """,
    )

    # Batch tab's Run button: tag every image in the uploaded ZIP (UI only).
    run_button2.click(
        fn=predict_batch,
        inputs=[batch_file],
        outputs=[result_text2],
        api_name=False,
    )
    # Public API endpoint "predict": tag the image from the Single tab's input
    # and return a comma-separated string.
    run_button_api.click(
        fn=predict_api,
        inputs=[image],
        outputs=[result_text2],
        api_name='predict',
    )

if __name__ == "__main__":
    # Enable Gradio's request queue (at most 20 pending events), then start
    # the server. queue() returns the same Blocks object, so launch it.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()