File size: 5,456 Bytes
48f5a5d
 
 
 
 
 
 
 
 
14af88b
48f5a5d
 
14af88b
 
48f5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6e2df6
48f5a5d
 
 
 
 
 
 
3358ce6
48f5a5d
 
 
 
 
 
 
14af88b
48f5a5d
 
 
3358ce6
14af88b
 
 
 
48f5a5d
14af88b
 
48f5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14af88b
 
 
 
 
 
 
 
 
48f5a5d
 
 
 
 
 
 
14af88b
 
48f5a5d
 
14af88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48f5a5d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import base64
import copy
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import requests
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, hf_hub_download
from PIL import Image

def request_to_endpoint(request: dict, endpoint_id: str = 's4i4um7xakaq37') -> requests.Response:
    """Send a synchronous generation request to a RunPod serverless endpoint.

    Args:
        request: JSON-serializable payload (the ``{"input": ...}`` body).
        endpoint_id: RunPod endpoint identifier; defaults to the project's endpoint.

    Returns:
        The raw ``requests.Response`` from the ``/runsync`` call.

    Raises:
        RuntimeError: if the RUNPOD_API_KEY environment variable is not set.
        requests.exceptions.Timeout: if the endpoint does not answer in time.
    """
    api_key = os.getenv("RUNPOD_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of a cryptic 401 from the API.
        raise RuntimeError("RUNPOD_API_KEY environment variable is not set")
    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    # json= serializes the payload for us (no manual json.dumps); the timeout
    # keeps a dead endpoint from hanging the Gradio worker forever.
    response = requests.post(url, headers=headers, json=request, timeout=600)

    return response

# Collects only the SDXL versions of the LoRAs by default.
def prepare_characters(hf_path: str = "OnMoon/loras", prefix: str = "sdxl_"):
    """Collect LoRA character names and configs from a Hugging Face repo.

    Only files named ``{prefix}<character>.safetensors`` (weights) and
    ``{prefix}<character>.json`` (config) are considered; by default this
    selects the SDXL variants.

    Args:
        hf_path: Hugging Face repo path to list.
        prefix: filename prefix that marks the wanted model family.

    Returns:
        A tuple ``(character_names, character_configs)`` where
        ``character_names`` is a list of character identifiers and
        ``character_configs`` maps each identifier to its parsed JSON config.
    """
    fs = HfFileSystem()

    files = fs.ls(hf_path, detail=False)
    full_prefix = f"{hf_path}/{prefix}"
    character_names = []
    character_configs = {}
    for path in files:
        if not path.startswith(full_prefix):
            continue
        if path.endswith(".safetensors"):
            character_names.append(path[len(full_prefix): -len(".safetensors")])
        elif path.endswith(".json"):
            character = path[len(full_prefix): -len(".json")]
            with fs.open(path, 'r', encoding='utf-8') as config_file:
                character_configs[character] = json.load(config_file)

    return character_names, character_configs

# Build the character list and per-character configs once at import time;
# both are read by the Gradio UI below and by process_input.
character_names, character_configs = prepare_characters()


def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Generate images for the selected character and return a gallery.

    Args:
        name: character identifier (key into ``character_configs``).
        scale: LoRA scale applied to both the model and cross-attention kwargs.
        triggers: comma-separated trigger words; empty string falls back to
            the character's configured ``trigger_words``.
        tech_prompt: technical prompt override; empty string uses the config.
        tech_negative_prompt: negative prompt override; empty string uses the config.
        prompts_count: number of user prompt boxes currently rendered.
        *args: the text of each prompt box, in order.

    Returns:
        A list of ``[PIL.Image, caption]`` pairs for a ``gr.Gallery``.
    """
    config = character_configs[name]

    # Use overrides when given; otherwise fall back to the character's config.
    triggers = triggers if triggers != "" else ",".join(config["trigger_words"])
    tech_prompt = tech_prompt if tech_prompt != "" else config["tech_prompt"]
    tech_negative_prompt = (
        tech_negative_prompt if tech_negative_prompt != "" else config["tech_negative_prompt"]
    )

    # Deep-copy before mutating: these dicts live in the module-level
    # character_configs cache, and writing into them directly would leak
    # the scale/lora of one request into every later request.
    model = copy.deepcopy(config["model"])
    model['loras'] = {name: scale}
    params = copy.deepcopy(config["params"])
    params["cross_attention_kwargs"]['scale'] = scale

    boxes = list(args)
    prompts = [f'{triggers}, {tech_prompt}, {boxes[n]}' for n in range(prompts_count)]
    negative_prompts = [tech_negative_prompt for _ in range(prompts_count)]

    request_data = {
        "input": {
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }

    response = request_to_endpoint(request_data)
    # Surface HTTP failures explicitly instead of a confusing KeyError below.
    response.raise_for_status()

    images = [
        Image.open(BytesIO(base64.b64decode(b64_image)))
        for b64_image in response.json()['output']['images']
    ]

    return [[images[i], prompts[i]] for i in range(prompts_count)]






                            ######################################################
                            #   ____               _ _         _                 #
                            #  / ___|_ __ __ _  __| (_) ___   / \   _ __  _ __   #
                            # | |  _| '__/ _` |/ _` | |/ _ \ / _ \ | '_ \| '_ \  #
                            # | |_| | | | (_| | (_| | | (_) / ___ \| |_) | |_) | #
                            #  \____|_|  \__,_|\__,_|_|\___/_/   \_\ .__/| .__/  #
                            #                                      |_|   |_|     #
############################################################################################################
with gr.Blocks() as demo:
    with gr.Group():
        # Character selection drives which LoRA weights and config are used.
        name = gr.Radio(
            character_names,
            label="Select character:",
            interactive=True,
            visible=True,
        )

        scale = gr.Slider(
            minimum=0,
            maximum=2.0,
            value=0.75,
            step=0.01,
            label="Selected LoRA scale:",
            interactive=True,
        )

        # Optional overrides; empty boxes fall back to the character's config.
        with gr.Accordion(label="Prompt overrides (optional)", open=False):
            triggers = gr.Textbox(label="Trigger words:")
            tech_prompt = gr.Textbox(label="Technical prompt:")
            tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")

    # Number of prompt boxes currently shown; drives the gr.render below.
    prompts_count = gr.State(1)

    with gr.Group():
        add_btn = gr.Button("Add prompt")
        del_btn = gr.Button("Delete prompt")

        add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
        # Clamp at 1: letting the count reach zero (or go negative) would
        # render no prompt boxes and make generation impossible.
        del_btn.click(lambda x: max(x - 1, 1), prompts_count, prompts_count)

        @gr.render(inputs=prompts_count)
        def render_count(count):
            # Re-rendered on every count change; keys keep box contents stable.
            boxes = []
            for i in range(count):
                with gr.Group():
                    prompt = gr.Textbox(key=str(i), label=f"Prompt {i+1}")
                boxes.append(prompt)

            # Re-bind the generate button to the current set of prompt boxes.
            generate_btn.click(
                process_input,
                [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
                output,
            )

    generate_btn = gr.Button("Generate!")

    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

demo.launch()
############################################################################################################