# Page metadata scraped along with this file (not code):
# honeysuckle · fixes · 14af88b · raw · history blame · 5.46 kB
import os
import json
import base64
import requests
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, hf_hub_download
def request_to_endpoint(request: dict, endpoint_id='s4i4um7xakaq37', *, timeout=600):
    """Send ``request`` synchronously to a RunPod serverless endpoint.

    The payload is JSON-encoded and POSTed to the endpoint's ``/runsync``
    route, authenticated with the ``RUNPOD_API_KEY`` environment variable.

    Args:
        request: JSON-serialisable payload.
        endpoint_id: RunPod endpoint identifier.
        timeout: Seconds to wait for the HTTP reply before giving up. The
            default is generous because synchronous image generation can be
            slow; without any timeout a stalled endpoint would block forever.

    Returns:
        The ``requests.Response`` of a successful (non-4xx/5xx) reply.

    Raises:
        RuntimeError: ``RUNPOD_API_KEY`` is not set.
        requests.HTTPError: the endpoint answered with a 4xx/5xx status.
        requests.Timeout: no reply arrived within ``timeout`` seconds.
    """
    api_key = os.getenv("RUNPOD_API_KEY")
    if not api_key:
        # requests drops headers whose value is None, so without this check
        # the call would silently go out unauthenticated.
        raise RuntimeError("RUNPOD_API_KEY environment variable is not set")
    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    response = requests.post(url, headers=headers, data=json.dumps(request), timeout=timeout)
    # Surface HTTP failures here instead of as a confusing KeyError when the
    # caller digs into the JSON body.
    response.raise_for_status()
    return response
# Collects only the SDXL versions (by default).
def prepare_characters(hf_path="OnMoon/loras", *, model_prefix="sdxl_", fs=None):
    """Collect LoRA characters and their JSON configs from a Hugging Face repo.

    Only entries listed directly under ``hf_path`` whose file names start with
    ``model_prefix`` are considered:

    * ``<prefix><name>.safetensors`` contributes the character ``name``;
    * ``<prefix><name>.json`` is parsed into that character's config.

    Args:
        hf_path: Repo path to list, e.g. ``"OnMoon/loras"``. A trailing ``/``
            is tolerated.
        model_prefix: File-name prefix selecting the model family.
        fs: Object providing ``ls(path, detail=False)`` and
            ``open(path, mode, encoding=...)`` like ``HfFileSystem``. A new
            ``HfFileSystem()`` is created when omitted; pass your own to use
            custom credentials.

    Returns:
        ``(character_names, character_configs)``. ``character_names`` keeps
        the listing order and includes only characters that also have a
        config. ``character_configs`` maps every found config name to its
        parsed JSON.
    """
    if fs is None:
        fs = HfFileSystem()
    # Listed paths look like "<hf_path>/<file>", so match on the full prefix.
    prefix = f"{hf_path.rstrip('/')}/{model_prefix}"

    weight_names = []
    character_configs = {}
    for path in fs.ls(hf_path, detail=False):
        if not path.startswith(prefix):
            continue
        if path.endswith(".safetensors"):
            weight_names.append(path[len(prefix):-len(".safetensors")])
        elif path.endswith(".json"):
            with fs.open(path, 'r', encoding='utf-8') as config_file:
                character_configs[path[len(prefix):-len(".json")]] = json.load(config_file)

    # Generation looks up the selected character's config, so a LoRA without
    # one would crash when picked. Don't offer it.
    character_names = [name for name in weight_names if name in character_configs]
    return character_names, character_configs
# load_dotenv was imported but never called, so a local .env file was ignored.
# Load it before anything reads the environment (e.g. RUNPOD_API_KEY in
# request_to_endpoint). override=False keeps already-set variables, such as
# Space secrets.
load_dotenv(override=False)
character_names, character_configs = prepare_characters()
def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Generate one image per prompt box for the selected character.

    Empty override fields fall back to the character's config. Each prompt
    sent to the endpoint is ``"<triggers>, <tech_prompt>, <box text>"``. All
    prompts share the same technical negative prompt.

    Args:
        name: Selected character (a key of ``character_configs``), or a falsy
            value when nothing is selected.
        scale: LoRA scale taken from the slider.
        triggers: Comma-separated trigger words; empty means "use the
            config's trigger words".
        tech_prompt: Technical prompt override; empty means the config default.
        tech_negative_prompt: Negative prompt override; empty means the config
            default.
        prompts_count: Number of prompt boxes currently shown.
        *args: Texts of the rendered prompt boxes, in order.

    Returns:
        List of ``(PIL.Image, prompt)`` pairs for ``gr.Gallery``.

    Raises:
        gr.Error: no character selected, no prompts, or the endpoint reply
            contains no images.
    """
    if not name or name not in character_configs:
        raise gr.Error("Select a character first.")
    config = character_configs[name]

    # User overrides win; empty fields fall back to the character's config.
    triggers = triggers or ",".join(config["trigger_words"])
    tech_prompt = tech_prompt or config["tech_prompt"]
    tech_negative_prompt = tech_negative_prompt or config["tech_negative_prompt"]

    # Build fresh dicts rather than editing the config in place: the configs
    # are module-level and shared by every (possibly concurrent) request.
    model = {**config["model"], "loras": {name: scale}}
    params = dict(config["params"])
    params["cross_attention_kwargs"] = {
        **(params.get("cross_attention_kwargs") or {}),
        "scale": scale,
    }

    # Use at most prompts_count boxes; clamp at 0 because a negative slice
    # bound would take boxes from the end.
    boxes = list(args)[:max(int(prompts_count), 0)]
    if not boxes:
        raise gr.Error("Add at least one prompt.")
    prompts = [f"{triggers}, {tech_prompt}, {box}" for box in boxes]
    negative_prompts = [f"{tech_negative_prompt}"] * len(prompts)

    request_data = {
        "input": {
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }
    payload = request_to_endpoint(request_data).json()

    # A failed or still-running job has no usable output; report it instead
    # of crashing with a KeyError.
    output = payload.get("output")
    if not isinstance(output, dict) or "images" not in output:
        raise gr.Error(f"Generation failed (status: {payload.get('status', 'unknown')}).")

    images = [Image.open(BytesIO(base64.b64decode(encoded))) for encoded in output["images"]]
    # zip() stops at the shorter list, so a reply with fewer images than
    # prompts can't raise IndexError.
    return [(image, prompt) for image, prompt in zip(images, prompts)]
######################################################
# ____ _ _ _ #
# / ___|_ __ __ _ __| (_) ___ / \ _ __ _ __ #
# | | _| '__/ _` |/ _` | |/ _ \ / _ \ | '_ \| '_ \ #
# | |_| | | | (_| | (_| | | (_) / ___ \| |_) | |_) | #
# \____|_| \__,_|\__,_|_|\___/_/ \_\ .__/| .__/ #
# |_| |_| #
############################################################################################################
with gr.Blocks() as demo:
    with gr.Group():
        name = gr.Radio(
            character_names,
            # Preselect the first character so "Generate!" works right away.
            value=character_names[0] if character_names else None,
            label="Select character:",
            interactive=True,
            visible=True,
        )
        scale = gr.Slider(
            minimum=0,
            maximum=2.0,
            value=0.75,
            step=0.01,
            label="Selected LoRA scale:",
            interactive=True,
        )
        # Fields left empty fall back to the character's config defaults.
        with gr.Accordion(label="Override character defaults", open=False):
            triggers = gr.Textbox(label="Trigger words:")
            tech_prompt = gr.Textbox(label="Technical prompt:")
            tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")

    # Number of prompt boxes currently shown; drives the dynamic render below.
    prompts_count = gr.State(1)
    with gr.Group():
        add_btn = gr.Button("Add prompt")
        del_btn = gr.Button("Delete prompt")
    add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
    # Keep at least one box: a zero or negative count leaves nothing to generate.
    del_btn.click(lambda x: max(x - 1, 1), prompts_count, prompts_count)

    @gr.render(inputs=prompts_count)
    def render_count(count):
        """Render ``count`` prompt boxes and wire them into the Generate button."""
        boxes = []
        for i in range(count):
            with gr.Group():
                # A stable key lets Gradio keep a box's text when the count
                # changes and the boxes are re-rendered.
                prompt = gr.Textbox(key=str(i), label=f"Prompt {i+1}")
            boxes.append(prompt)
        # generate_btn and output are created below. This body only runs when
        # Gradio renders, by which time they exist.
        generate_btn.click(
            process_input,
            [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
            output,
        )

    generate_btn = gr.Button("Generate!")
    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

if __name__ == "__main__":
    demo.launch()
############################################################################################################