Spaces:
Sleeping
Sleeping
File size: 5,456 Bytes
48f5a5d 14af88b 48f5a5d 14af88b 48f5a5d a6e2df6 48f5a5d 3358ce6 48f5a5d 14af88b 48f5a5d 3358ce6 14af88b 48f5a5d 14af88b 48f5a5d 14af88b 48f5a5d 14af88b 48f5a5d 14af88b 48f5a5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import json
import base64
import requests
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, hf_hub_download
def request_to_endpoint(request: dict, endpoint_id='s4i4um7xakaq37', timeout=300):
    """Send a synchronous job request to a RunPod endpoint.

    Args:
        request: JSON-serialisable payload; it is dumped with ``json.dumps``
            and sent as the request body.
        endpoint_id: Endpoint identifier interpolated into the ``/runsync`` URL.
        timeout: Seconds to wait for the HTTP call before ``requests`` raises
            a timeout error. Pass ``None`` to wait indefinitely.

    Returns:
        The ``requests.Response`` object, unmodified. The HTTP status is not
        checked here; that is left to the caller.

    Raises:
        RuntimeError: If the ``RUNPOD_API_KEY`` environment variable is unset
            or empty.
    """
    api_key = os.getenv("RUNPOD_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of sending a request that
        # carries no credentials.
        raise RuntimeError("RUNPOD_API_KEY environment variable is not set")

    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    # Without a timeout a stalled endpoint would block this worker forever.
    return requests.post(
        url,
        headers=headers,
        data=json.dumps(request),
        timeout=timeout,
    )
# Collects only the sdxl versions
def prepare_characters(hf_path="OnMoon/loras"):
    """List the ``sdxl_*`` entries under *hf_path* on the Hugging Face Hub.

    Every ``sdxl_<character>.safetensors`` file contributes ``<character>`` to
    the list of names; every ``sdxl_<character>.json`` file is parsed and
    stored under ``<character>`` in the config mapping.

    Returns:
        A ``(character_names, character_configs)`` tuple: a list of names in
        listing order and a dict mapping names to their parsed JSON configs.
    """
    hub = HfFileSystem()
    prefix = f"{hf_path}/sdxl_"
    weights_ext = ".safetensors"
    config_ext = ".json"

    names = []
    configs = {}
    for path in hub.ls(hf_path, detail=False):
        if not path.startswith(prefix):
            continue
        if path.endswith(weights_ext):
            names.append(path[len(prefix): -len(weights_ext)])
        elif path.endswith(config_ext):
            key = path[len(prefix): -len(config_ext)]
            with hub.open(path, 'r', encoding='utf-8') as handle:
                configs[key] = json.load(handle)
    return names, configs
# Load the character list and configs once, at import time; the UI (Radio
# choices) and process_input read them as module-level globals.
character_names, character_configs = prepare_characters()
def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Generate one image per prompt box and return them as gallery items.

    Args:
        name: Selected character; must be a key of ``character_configs``.
        scale: LoRA scale, sent both as the ``loras`` weight and as
            ``cross_attention_kwargs["scale"]``.
        triggers: Custom trigger words; empty means use the character's
            ``trigger_words`` joined with commas.
        tech_prompt: Custom technical prompt; empty means use the config value.
        tech_negative_prompt: Custom negative technical prompt; empty means
            use the config value.
        prompts_count: Number of prompt boxes to use.
        *args: Text of the prompt boxes, in order.

    Returns:
        A list of ``[image, prompt]`` pairs for ``gr.Gallery``; empty when
        there are no prompts to generate.

    Raises:
        gr.Error: If no character is selected, the character has no config,
            or the endpoint response contains no images.
    """
    if name is None:
        raise gr.Error("Select a character before generating.")
    if name not in character_configs:
        raise gr.Error(f"No configuration found for character '{name}'.")
    config = character_configs[name]

    # If we want to set our own trigger words or technical prompt; empty
    # fields fall back to the values stored in the character's config.
    triggers = triggers or ",".join(config["trigger_words"])
    tech_prompt = tech_prompt or config["tech_prompt"]
    tech_negative_prompt = tech_negative_prompt or config["tech_negative_prompt"]

    # Build per-request copies so the shared module-level config is never
    # mutated (requests with different scales must not affect each other).
    model = {**config["model"], "loras": {name: scale}}
    params = dict(config["params"])
    params["cross_attention_kwargs"] = {
        **(params.get("cross_attention_kwargs") or {}),
        "scale": scale,
    }

    # Use at most prompts_count boxes, but never index past what was passed.
    boxes = args[:max(prompts_count, 0)]
    prompts = [f'{triggers}, {tech_prompt}, {box}' for box in boxes]
    if not prompts:
        return []
    negative_prompts = [f"{tech_negative_prompt}"] * len(prompts)

    request_data = {
        "input": {
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }

    response = request_to_endpoint(request_data)
    response.raise_for_status()
    try:
        encoded_images = response.json()['output']['images']
    except (KeyError, TypeError) as err:
        raise gr.Error("The endpoint response contained no images.") from err

    images = [
        Image.open(BytesIO(base64.b64decode(encoded)))
        for encoded in encoded_images
    ]
    # zip stops at the shorter sequence, so a short response cannot raise
    # IndexError here.
    return [[image, f"{prompt}"] for image, prompt in zip(images, prompts)]
######################################################
# ____ _ _ _ #
# / ___|_ __ __ _ __| (_) ___ / \ _ __ _ __ #
# | | _| '__/ _` |/ _` | |/ _ \ / _ \ | '_ \| '_ \ #
# | |_| | | | (_| | (_| | | (_) / ___ \| |_) | |_) | #
# \____|_| \__,_|\__,_|_|\___/_/ \_\ .__/| .__/ #
# |_| |_| #
############################################################################################################
# UI layout: character and LoRA settings, a dynamic list of prompt boxes,
# then the generate button and the results gallery.
with gr.Blocks() as demo:
    with gr.Group():
        name = gr.Radio(
            character_names,
            label="Select character:",
            interactive=True,
            visible=True,
        )
        scale = gr.Slider(
            minimum=0,
            maximum=2.0,
            value=0.75,
            step=0.01,
            label="Selected LoRA scale:",
            interactive=True,
        )
        # Optional overrides; empty fields fall back to the character config.
        with gr.Accordion(open=False):
            triggers = gr.Textbox(label="Trigger words:")
            tech_prompt = gr.Textbox(label="Technical prompt:")
            tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")

    # Number of prompt boxes currently shown.
    prompts_count = gr.State(1)
    with gr.Group():
        add_btn = gr.Button("Add prompt")
        del_btn = gr.Button("Delete prompt")
    add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
    # Never drop below one box; otherwise the count could reach zero or go
    # negative and leave the user with no prompt field at all.
    del_btn.click(lambda x: max(x - 1, 1), prompts_count, prompts_count)

    @gr.render(inputs=prompts_count)
    def render_count(count):
        """Re-create the prompt boxes and rebind the generate button."""
        boxes = []
        for i in range(count):
            with gr.Group():
                # A stable key lets Gradio keep the typed text across re-renders.
                prompt = gr.Textbox(key=str(i), label=f"Prompt {i+1}")
                boxes.append(prompt)
        # The listener is attached here because its inputs include the
        # dynamically created boxes; generate_btn and output are defined
        # below and resolved when this function runs.
        generate_btn.click(
            process_input,
            [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
            output,
        )

    generate_btn = gr.Button("Generate!")
    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

demo.launch()
############################################################################################################ |