import copy
import os
import json
import base64
import requests
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, hf_hub_download

# Read settings such as RUNPOD_API_KEY from a local .env file, if one exists.
# By default load_dotenv does not override variables already set in the
# environment (e.g. Space secrets).
load_dotenv()


def request_to_endpoint(request: dict, endpoint_id='s4i4um7xakaq37', timeout=300):
    """POST a job to a RunPod endpoint's ``runsync`` route and return the response.

    Args:
        request: JSON-serialisable job payload (see ``process_input`` for the
            shape built in this app).
        endpoint_id: RunPod endpoint identifier.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The raw ``requests.Response``. The caller checks the status and decodes
        the body.

    Raises:
        RuntimeError: if ``RUNPOD_API_KEY`` is not set.
        requests.RequestException: on connection failure or timeout.
    """
    api_key = os.getenv("RUNPOD_API_KEY")
    if not api_key:
        # Fail loudly instead of sending an unauthenticated request.
        raise RuntimeError("RUNPOD_API_KEY is not set (environment or .env file).")
    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    return requests.post(url, headers=headers, data=json.dumps(request), timeout=timeout)


def prepare_characters(hf_path="OnMoon/loras"):
    """Collect the SDXL characters stored in a Hugging Face repo (SDXL versions only).

    Only files named ``sdxl_<character>.safetensors`` or
    ``sdxl_<character>.json`` directly under *hf_path* are considered.

    Returns:
        A ``(character_names, character_configs)`` tuple:

        - ``character_names``: characters that have a ``.safetensors`` file.
        - ``character_configs``: maps a character name to its parsed ``.json``
          config.
    """
    fs = HfFileSystem()
    prefix = f"{hf_path}/sdxl_"
    character_names = []
    character_configs = {}
    for path in fs.ls(hf_path, detail=False):
        if not path.startswith(prefix):
            continue
        if path.endswith(".safetensors"):
            character_names.append(path[len(prefix):-len(".safetensors")])
        elif path.endswith(".json"):
            character = path[len(prefix):-len(".json")]
            with fs.open(path, 'r', encoding='utf-8') as config_file:
                character_configs[character] = json.load(config_file)
    return character_names, character_configs


character_names, character_configs = prepare_characters()


def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Build a generation request for the selected character and return gallery items.

    Args:
        name: Selected character (a key of ``character_configs``).
        scale: LoRA scale. It is written to ``model["loras"]`` and to
            ``params["cross_attention_kwargs"]["scale"]``.
        triggers, tech_prompt, tech_negative_prompt: User overrides. An empty
            value falls back to the character's JSON config.
        prompts_count: Number of prompt textboxes currently rendered.
        *args: Text of the rendered prompt textboxes, in order.

    Returns:
        A list of ``[PIL.Image, prompt]`` pairs for ``gr.Gallery``.

    Raises:
        gr.Error: for missing input, a missing config, or a failed or
            incomplete endpoint response.
    """
    if not name:
        raise gr.Error("Select a character first.")
    if name not in character_configs:
        raise gr.Error(f"No config file found for character '{name}'.")
    config = character_configs[name]

    # Custom trigger words or technical prompts override the config defaults
    # when provided.
    triggers = triggers if triggers else ",".join(config["trigger_words"])
    tech_prompt = tech_prompt if tech_prompt else config["tech_prompt"]
    tech_negative_prompt = (
        tech_negative_prompt if tech_negative_prompt else config["tech_negative_prompt"]
    )

    # Deep-copy so per-request edits never leak into the shared, cached config.
    model = copy.deepcopy(config["model"])
    model['loras'] = {name: scale}
    params = copy.deepcopy(config["params"])
    params.setdefault("cross_attention_kwargs", {})['scale'] = scale

    # Slice instead of indexing so a count/box mismatch cannot raise IndexError.
    user_prompts = list(args)[:prompts_count]
    if not user_prompts:
        raise gr.Error("Add at least one prompt.")
    prompts = [f'{triggers}, {tech_prompt}, {box}' for box in user_prompts]
    negative_prompts = [f"{tech_negative_prompt}"] * len(prompts)

    request_data = {
        "input": {
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }

    try:
        response = request_to_endpoint(request_data)
    except requests.RequestException as err:
        raise gr.Error("Could not reach the generation endpoint.") from err
    if not response.ok:
        raise gr.Error(f"Generation endpoint returned HTTP {response.status_code}.")
    try:
        encoded_images = response.json()['output']['images']
    except (ValueError, KeyError, TypeError) as err:
        # Body is not JSON or lacks output.images. Presumably the job is
        # unfinished or the worker failed; the exact payload is not visible here.
        raise gr.Error("The generation endpoint did not return any images.") from err

    images = [Image.open(BytesIO(base64.b64decode(encoded))) for encoded in encoded_images]
    # Caption each image with its full prompt. zip() tolerates the endpoint
    # returning fewer images than prompts.
    return [[image, prompt] for image, prompt in zip(images, prompts)]


# --------------------------------------------------------------------------
# Gradio app
# --------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Group():
        name = gr.Radio(
            character_names,
            label="Select character:",
            interactive=True,
            visible=True,
        )
        scale = gr.Slider(
            minimum=0,
            maximum=2.0,
            value=0.75,
            step=0.01,
            label="Selected LoRA scale:",
            interactive=True,
        )
    with gr.Accordion(label="Advanced settings", open=False):
        # Leave these empty to use the character's defaults from its JSON config.
        triggers = gr.Textbox(label="Trigger words:")
        tech_prompt = gr.Textbox(label="Technical prompt:")
        tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")

    prompts_count = gr.State(1)
    with gr.Group():
        add_btn = gr.Button("Add prompt")
        del_btn = gr.Button("Delete prompt")
    add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
    # Keep at least one prompt box. Otherwise the count can go negative and
    # several "Add" clicks are needed before a box reappears.
    del_btn.click(lambda x: max(x - 1, 1), prompts_count, prompts_count)

    @gr.render(inputs=prompts_count)
    def render_count(count):
        """Render one textbox per prompt and bind the Generate button to them."""
        boxes = []
        for i in range(count):
            with gr.Group():
                prompt = gr.Textbox(key=str(i), label=f"Prompt {i+1}")
            boxes.append(prompt)
        # generate_btn and output are defined below. This assumes Gradio runs
        # render functions only after the Blocks layout is built, as in the
        # gr.render documentation examples.
        generate_btn.click(
            process_input,
            [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
            output,
        )

    generate_btn = gr.Button("Generate!")
    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

demo.launch()