File size: 2,957 Bytes
79a4708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e4e69
 
 
ea44f4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import shutil
import sys

import gradio as gr
from IPython.display import display
from PIL import Image

def _caption_path(directory, image_name):
    """Return the path of the ``.txt`` caption paired with *image_name*.

    The caption shares the image's base name, e.g. ``cat.jpg`` -> ``cat.txt``.
    """
    stem, _ext = os.path.splitext(image_name)
    return os.path.join(directory, stem + ".txt")


def _load_thumbnails(source_dir, thumbnail_size):
    """Load every ``.jpg`` file in *source_dir* as an in-memory thumbnail.

    Returns a list of ``{"name": filename, "image": PIL.Image}`` dicts, sorted
    by filename so the grid order does not depend on ``os.listdir`` order.
    """
    images = []
    for filename in sorted(os.listdir(source_dir)):
        if not filename.lower().endswith(".jpg"):
            continue
        # ``with`` closes the file handle; ``copy()`` detaches the resized
        # pixels from it, because a closed PIL image can no longer be used.
        with Image.open(os.path.join(source_dir, filename)) as img:
            img.thumbnail(thumbnail_size)
            images.append({"name": filename, "image": img.copy()})
    return images


def load_images_and_texts(source_file, target_file, thumbnail_size=(256, 256)):
    """Build and launch a UI for writing captions for the JPEGs in a folder.

    Each ``.jpg`` in *source_file* is shown as a clickable thumbnail. Clicking
    one selects it, displays it and shows its existing ``.txt`` caption (if
    any). "Generate" copies the selected image into *target_file* and writes
    the prompt text next to it as ``<image stem>.txt``.

    NOTE(review): ``gr.outputs.Button``, ``gr.outputs.GridBox`` and
    ``gr.inputs.Layout`` are used exactly as in the original; confirm they
    exist in the installed Gradio version.

    Args:
        source_file: Directory containing the source ``.jpg`` images and
            optional ``.txt`` captions.
        target_file: Directory that receives the selected image and its
            prompt. Created if it does not exist.
        thumbnail_size: Maximum ``(width, height)`` of the grid thumbnails.

    Returns:
        Whatever ``gr.Interface(...).launch()`` returns.
    """
    images = _load_thumbnails(source_file, thumbnail_size)

    # Name and thumbnail of the most recently clicked image; None until a click.
    selected_image = None

    def select_image(btn):
        """Remember *btn*'s image as the one ``generate_prompt`` will export."""
        nonlocal selected_image
        selected_image = {"name": btn.metadata["name"], "image": btn.metadata["image"]}

    def click_image(btn):
        """Select the clicked image, display it and load its caption."""
        # Without this call nothing ever set ``selected_image``, so the
        # Generate button silently did nothing.
        select_image(btn)
        img_name = btn.metadata["name"]
        with Image.open(os.path.join(source_file, img_name)) as img:
            display(img)

        try:
            with open(_caption_path(source_file, img_name), encoding="utf-8") as f:
                text_area.value = f.read()
        except FileNotFoundError:
            # An image that has no caption yet is a normal case; show it empty.
            text_area.value = ""

    def generate_prompt(btn):
        """Export the selected image to *target_file* with the prompt as caption."""
        if selected_image is None:
            return
        name = selected_image["name"]
        os.makedirs(target_file, exist_ok=True)
        # Copy the original file. The stored image is only a thumbnail, and
        # the previous code called ``.save()`` on the dict (AttributeError).
        shutil.copy2(os.path.join(source_file, name), os.path.join(target_file, name))
        with open(_caption_path(target_file, name), "w", encoding="utf-8") as f:
            f.write(prompt_area.value)

    image_buttons = [gr.outputs.Button(label=image_dict["name"],
                                       metadata={"name": image_dict["name"], "image": image_dict["image"]},
                                       type="image") for image_dict in images]

    image_box = gr.outputs.GridBox(children=image_buttons, layout=gr.inputs.Layout(grid_template_columns="repeat(5, 1fr)"))
    text_area = gr.outputs.Textbox()
    prompt_area = gr.inputs.Textbox(label="Generate Prompt")
    generate_button = gr.outputs.Button(label="Generate", type="button")
    generate_button.action = generate_prompt

    for button in image_buttons:
        button.action = click_image
        button.style["cursor"] = "pointer"
        button.style["margin"] = "5px"
        button.style["border"] = "2px solid white"
        button.style["border-radius"] = "5px"
        button.style["box-shadow"] = "0px 0px 4px 4px rgba(0, 0, 0, 0.1)"
        button.style["background-color"] = "transparent"
        button.style["padding"] = 0
        button.style["width"] = "100%"
        button.style["height"] = "100%"

    return gr.Interface([image_box, text_area, prompt_area, generate_button], "float", source_file=gr.inputs.Folder(), target_file=gr.inputs.Folder()).launch()


if __name__ == "__main__":
    # The previous call passed Gradio component objects where directory
    # paths are required, so ``os.listdir`` failed as soon as the module
    # was imported. Take the directories from the command line instead.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} SOURCE_DIR TARGET_DIR")
    load_images_and_texts(sys.argv[1], sys.argv[2])