honeysuckle committed on
Commit
48f5a5d
·
1 Parent(s): f580a15
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import requests
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ from huggingface_hub import HfFileSystem, hf_hub_download
11
+
12
+
13
def request_to_endpoint(request: dict, endpoint_id: str = '4556prwagxw5co'):
    """POST *request* to a RunPod serverless endpoint's synchronous route.

    Parameters
    ----------
    request : dict
        JSON-serializable payload, e.g. ``{"input": {"prompt": ...}}``.
    endpoint_id : str
        RunPod endpoint identifier (defaults to the project's endpoint).

    Returns
    -------
    requests.Response
        The raw HTTP response; callers read ``.json()`` themselves.

    Raises
    ------
    KeyError
        If the ``RUNPOD_API_KEY`` environment variable is not set.
    """
    # SECURITY FIX: the API key used to be hard-coded in this file (and is
    # therefore leaked in the repo history — rotate it!). Read it from the
    # environment instead so it is never committed again.
    api_key = os.environ["RUNPOD_API_KEY"]
    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    # json= lets requests serialize the payload; a timeout prevents the UI
    # from hanging forever if the serverless worker never answers.
    response = requests.post(url, headers=headers, json=request, timeout=600)

    return response
24
+
25
# Collects only the SDXL versions of the LoRA assets.
def prepare_characters(hf_path="OnMoon/loras"):
    """Scan *hf_path* on the Hugging Face Hub for ``sdxl_*`` files.

    Returns
    -------
    tuple[list[str], dict]
        ``(names, configs)`` — *names* lists every character that has an
        ``sdxl_<name>.safetensors`` weight file; *configs* maps character
        name to its parsed ``sdxl_<name>.json`` configuration.
    """
    fs = HfFileSystem()
    prefix = f"{hf_path}/sdxl_"

    names = []
    configs = {}
    for entry in fs.ls(hf_path, detail=False):
        if not entry.startswith(prefix):
            continue
        if entry.endswith(".safetensors"):
            names.append(entry[len(prefix):-len(".safetensors")])
        elif entry.endswith(".json"):
            with fs.open(entry, 'r', encoding='utf-8') as config_file:
                configs[entry[len(prefix):-len(".json")]] = json.load(config_file)

    return names, configs
42
+
43
# Discover the available characters and their configs once at startup.
character_names, character_configs = prepare_characters()

# Invoke endpoint with a throwaway request so the serverless worker starts
# warming up before the first real generation (the response is ignored).
request_to_endpoint({"input": {"prompt": ""}})
47
+
48
+
49
def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Generate images for every (prompt, negative prompt) pair in the UI.

    Parameters
    ----------
    name : str
        Selected character; must be a key of ``character_configs``.
    scale : float
        LoRA weight applied to the character's model.
    triggers, tech_prompt, tech_negative_prompt : str
        Optional overrides; empty string means "use the character's default".
    prompts_count : int
        Number of (prompt, negative prompt) pairs present in ``args``.
    *args
        Flattened textbox values: ``prompt_0, negative_0, prompt_1, ...``.

    Returns
    -------
    list
        Gallery entries ``[image, caption]`` suitable for ``gr.Gallery``.
    """
    boxes = list(args)
    prompts = []
    negative_prompts = []

    config = character_configs[name]

    # Empty override boxes fall back to the character's own defaults.
    triggers = triggers if triggers != "" else ",".join(config["trigger_words"])
    tech_prompt = tech_prompt if tech_prompt != "" else config["tech_prompt"]
    tech_negative_prompt = tech_negative_prompt if tech_negative_prompt != "" else config["tech_negative_prompt"]

    # BUGFIX: build a copy before attaching 'loras'. The original assigned
    # into config["model"] directly, mutating the shared character_configs
    # cache on every call (stale/overwritten LoRA scales across requests).
    model = {**config["model"], "loras": {name: scale}}
    params = config["params"]

    for n in range(prompts_count):
        prompts.append(f'{triggers}, {tech_prompt}, {boxes[2*n]}')
        negative_prompts.append(f"{tech_negative_prompt}, {boxes[2*n + 1]}")

    request_data = {
        "input": {
            "mode": "inference",
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }

    response = request_to_endpoint(request_data)

    # Each returned image arrives as a base64-encoded string.
    images = []
    for base64_string in response.json()['output']['images']:
        img = Image.open(BytesIO(base64.b64decode(base64_string)))
        images.append(img)

    gallery = [[images[i], f"{prompts[i]}"] for i in range(prompts_count)]

    return gallery
87
+
88
+
89
+
90
+
91
+
92
+
93
######################################################
#  ____               _ _          _                 #
# / ___|_ __ __ _  __| (_) ___    / \   _ __  _ __   #
#| |  _| '__/ _` |/ _` | |/ _ \  / _ \ | '_ \| '_ \  #
#| |_| | | | (_| | (_| | | (_) |/ ___ \| |_) | |_) | #
# \____|_|  \__,_|\__,_|_|\___//_/   \_\ .__/| .__/  #
#                                      |_|   |_|     #
############################################################################################################
with gr.Blocks() as demo:
    with gr.Group():
        name = gr.Radio(
            character_names,
            label="Select character:",
            interactive=True,
            visible=True,
        )

    # BUGFIX: the accordion previously had no label, and the slider label was
    # f"{name} scale:" which interpolated the Radio *component object* (its
    # repr), not the selected character.
    with gr.Accordion("Advanced settings", open=False):
        scale = gr.Slider(0, 2.0, 0.01, label="LoRA scale:", value=0.75, interactive=True)
        triggers = gr.Textbox(label="Trigger words:")
        tech_prompt = gr.Textbox(label="Technical prompt:")
        tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")

    # Number of (prompt, negative prompt) pairs currently shown.
    prompts_count = gr.State(1)

    with gr.Row():
        with gr.Column():
            add_btn = gr.Button("Add prompt")
            del_btn = gr.Button("Delete prompt")

    add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
    # BUGFIX: clamp at 1 — the original could decrement to 0 or negative,
    # removing every prompt box from the UI.
    del_btn.click(lambda x: max(1, x - 1), prompts_count, prompts_count)

    # Re-renders the prompt boxes whenever prompts_count changes; the click
    # wiring must live here because the boxes only exist inside the render.
    @gr.render(inputs=prompts_count)
    def render_count(count):
        boxes = []
        for i in range(count):
            with gr.Group():
                # BUGFIX: keys must be unique per component — both boxes
                # previously shared key=i, breaking value retention.
                prompt = gr.Textbox(key=f"prompt_{i}", label=f"Prompt {i+1}")
                negative_prompt = gr.Textbox(key=f"negative_{i}", label=f"Negative prompt {i+1}")

            boxes.append(prompt)
            boxes.append(negative_prompt)

        generate_btn.click(
            process_input,
            [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
            output,
        )

    generate_btn = gr.Button("Generate!")

    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

demo.launch()
############################################################################################################