# Hugging Face Space: SDXL character LoRA image generator (Gradio front-end,
# generation served by a RunPod serverless endpoint).
import base64
import copy
import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import requests
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem, hf_hub_download
from PIL import Image
def request_to_endpoint(request: dict, endpoint_id='s4i4um7xakaq37'):
    """Send a synchronous generation request to a RunPod serverless endpoint.

    Args:
        request: JSON-serializable payload (the ``{"input": ...}`` body).
        endpoint_id: RunPod endpoint identifier to call.

    Returns:
        requests.Response: the raw HTTP response from the ``/runsync`` call.

    Raises:
        requests.Timeout: if the endpoint does not answer within the timeout.
    """
    # Auth token is expected in the environment (loaded e.g. via dotenv).
    api_key = os.getenv("RUNPOD_API_KEY")
    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        "accept": "application/json",
        "authorization": api_key,
        "content-type": "application/json",
    }
    # json= serializes the payload for us; the explicit timeout prevents the
    # request from hanging forever (requests has NO default timeout) — runsync
    # image generation is slow but should never exceed ten minutes.
    response = requests.post(url, headers=headers, json=request, timeout=600)
    return response
def prepare_characters(hf_path="OnMoon/loras"):
    """Collect SDXL character LoRAs and their configs from a HF repo.

    Only SDXL versions are gathered: files named
    ``sdxl_<character>.safetensors`` contribute a character name, and
    ``sdxl_<character>.json`` contributes that character's config dict.
    Everything else in the repo is ignored.

    Returns:
        tuple[list[str], dict[str, dict]]: character names and the
        per-character configs parsed from the JSON files.
    """
    fs = HfFileSystem()
    prefix = f"{hf_path}/sdxl_"
    names = []
    configs = {}
    for path in fs.ls(hf_path, detail=False):
        if not path.startswith(prefix):
            continue
        if path.endswith(".safetensors"):
            names.append(path[len(prefix):-len(".safetensors")])
        elif path.endswith(".json"):
            with fs.open(path, 'r', encoding='utf-8') as fp:
                configs[path[len(prefix):-len(".json")]] = json.load(fp)
    return names, configs
# Module-level startup: fetch the available characters and their configs once,
# before the UI below is built (the Radio choices depend on character_names).
character_names, character_configs = prepare_characters()
def process_input(name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count, *args):
    """Build prompts for the selected character, call the endpoint, return a gallery.

    Args:
        name: selected character (key into the module-level character_configs).
        scale: LoRA scale applied to both the model's lora map and
            cross_attention_kwargs.
        triggers, tech_prompt, tech_negative_prompt: user overrides; when
            empty, the character's config defaults are used instead.
        prompts_count: number of per-image prompt textboxes currently shown.
        *args: the per-image prompt strings, one per textbox.

    Returns:
        list[list]: gallery entries ``[image, caption]``, one per prompt.
    """
    config = character_configs[name]
    # User-provided values override the config defaults.
    triggers = triggers if triggers != "" else ",".join(config["trigger_words"])
    tech_prompt = tech_prompt if tech_prompt != "" else config["tech_prompt"]
    tech_negative_prompt = (
        tech_negative_prompt if tech_negative_prompt != ""
        else config["tech_negative_prompt"]
    )
    # BUG FIX: deep-copy before mutating. The original wrote the LoRA scale
    # straight into the shared character_configs cache, so one request's
    # settings leaked into every later request for the same character.
    model = copy.deepcopy(config["model"])
    model['loras'] = {name: scale}
    params = copy.deepcopy(config["params"])
    params["cross_attention_kwargs"]['scale'] = scale
    boxes = list(args)
    prompts = [f'{triggers}, {tech_prompt}, {boxes[n]}' for n in range(prompts_count)]
    negative_prompts = [tech_negative_prompt for _ in range(prompts_count)]
    request_data = {
        "input": {
            "model": model,
            "params": params,
            "prompt": prompts,
            "negative_prompt": negative_prompts,
            "height": 1216,
            "width": 832,
        }
    }
    response = request_to_endpoint(request_data)
    # Surface HTTP failures explicitly instead of a cryptic KeyError below.
    response.raise_for_status()
    images = [
        Image.open(BytesIO(base64.b64decode(b64)))
        for b64 in response.json()['output']['images']
    ]
    return [[images[i], prompts[i]] for i in range(prompts_count)]
######################################################
#   ____               _ _          _                #
#  / ___|_ __ __ _  __| (_) ___    / \   _ __  _ __  #
# | |  _| '__/ _` |/ _` | |/ _ \  / _ \ | '_ \| '_ \ #
# | |_| | | | (_| | (_| | | (_) |/ ___ \| |_) | |_) | #
#  \____|_|  \__,_|\__,_|_|\___//_/   \_\ .__/| .__/ #
#                                       |_|   |_|    #
######################################################
with gr.Blocks() as demo:
    with gr.Group():
        name = gr.Radio(
            character_names,
            label="Select character:",
            interactive=True,
            visible=True,
        )
        scale = gr.Slider(
            minimum=0,
            maximum=2.0,
            value=0.75,
            step=0.01,
            label="Selected LoRA scale:",
            interactive=True,
        )
    with gr.Accordion(open=False):
        triggers = gr.Textbox(label="Trigger words:")
        tech_prompt = gr.Textbox(label="Technical prompt:")
        tech_negative_prompt = gr.Textbox(label="Negative technical prompt:")
    # Number of prompt boxes currently shown; drives the dynamic render below.
    prompts_count = gr.State(1)
    with gr.Group():
        add_btn = gr.Button("Add prompt")
        del_btn = gr.Button("Delete prompt")
    add_btn.click(lambda x: x + 1, prompts_count, prompts_count)
    # Clamp at one prompt so the count can never reach zero or go negative.
    del_btn.click(lambda x: max(1, x - 1), prompts_count, prompts_count)

    generate_btn = gr.Button("Generate!")
    output = gr.Gallery(
        label="Generation results:",
        object_fit="contain",
        height="auto",
    )

    # BUG FIX: the original defined render_count but never invoked it — the
    # @gr.render decorator was missing, so no prompt textboxes (and no click
    # wiring) ever appeared. The render block must also come after
    # generate_btn/output, because gr.render executes at initial render time.
    @gr.render(inputs=prompts_count)
    def render_count(count):
        boxes = []
        for i in range(count):
            with gr.Group():
                prompt = gr.Textbox(key=str(i), label=f"Prompt {i+1}")
                boxes.append(prompt)
        generate_btn.click(
            process_input,
            [name, scale, triggers, tech_prompt, tech_negative_prompt, prompts_count] + boxes,
            output,
        )

demo.launch()
############################################################################################################