Spaces:
Running on Zero
Running on Zero
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import PIL.Image | |
| from PIL import Image, PngImagePlugin | |
| import random | |
| from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler, UniPCMultistepScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler | |
| import torch | |
| from compel import Compel, ReturnedEmbeddingsType | |
| import requests | |
| import os | |
| import re | |
| import gc | |
| from huggingface_hub import hf_hub_download | |
# This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
# The real inference function is only wrapped with `spaces.GPU` at request time
# (see `infer`), so this placeholder is the one function decorated at import time.
# NOTE(review): assumes the startup check looks for a `spaces.GPU`-decorated
# function at import — confirm against the ZeroGPU documentation.
@spaces.GPU
def dummy_gpu_for_startup():
    """Placeholder GPU task; prints a notice and returns a fixed status string."""
    print("Dummy function for startup check executed. This is normal.")
    return "Startup check passed."
# --- Constants ---
# Number of LoRA slots offered by the UI.
MAX_LORAS = 5
# Upper bound of the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1216
# Largest value used when drawing a random seed.
MAX_SEED = np.iinfo(np.int64).max

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sampler label shown in the UI -> diffusers scheduler class.
SAMPLER_MAP = dict([
    ("Euler a", EulerAncestralDiscreteScheduler),
    ("Euler", EulerDiscreteScheduler),
    ("DPM++ 2M Karras", DPMSolverMultistepScheduler),
    ("DDIM", DDIMScheduler),
    ("UniPC", UniPCMultistepScheduler),
    ("Heun", HeunDiscreteScheduler),
    ("LMS", LMSDiscreteScheduler),
])

# Schedule-type labels offered by the UI (a plain list despite the name).
SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]

DEFAULT_SAMPLER = "Euler a"
DEFAULT_SCHEDULE_TYPE = "Default"
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"

# Downloaded LoRA files are stored here; the directory is created at import time.
DOWNLOAD_DIR = "/tmp/loras"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
# --- Model Lists ---
# Hugging Face repo ids offered in the "Base Model" dropdown.
MODEL_LIST = [
    "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    "Laxhar/noobai-XL-Vpred-1.0",
    "John6666/hassaku-xl-illustrious-v30-sdxl",
    "RedRayz/hikari_noob_v-pred_1.2.2",
    "bluepen5805/noob_v_pencil-XL",
    "Laxhar/noobai-XL-1.1",
]

# --- List of V-Prediction Models ---
# Models in this list get a scheduler configured for 'v_prediction';
# every other model uses 'epsilon'.
V_PREDICTION_MODELS = [
    "Laxhar/noobai-XL-Vpred-1.0",
    "RedRayz/hikari_noob_v-pred_1.2.2",
    "bluepen5805/noob_v_pencil-XL",
]

# --- Dictionary for single-file models now stores the filename ---
# Repo id -> checkpoint filename, for repos loaded through `from_single_file`.
SINGLE_FILE_MODELS = dict([
    ("bluepen5805/noob_v_pencil-XL", "noob_v_pencil-XL-v3.0.0.safetensors"),
])

# --- Model Hash to Name Mapping ---
# "Model hash" value found in image metadata -> repo id from MODEL_LIST.
HASH_TO_MODEL_MAP = dict([
    ("bdb59bac77", "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl"),
    ("ea349eeae8", "Laxhar/noobai-XL-Vpred-1.0"),
    ("b4fb5f829a", "John6666/hassaku-xl-illustrious-v30-sdxl"),
    ("6681e8e4b1", "Laxhar/noobai-XL-1.1"),
    ("90b7911a78", "bluepen5805/noob_v_pencil-XL"),
    ("874170688a", "RedRayz/hikari_noob_v-pred_1.2.2"),
])
def get_civitai_file_info(version_id):
    """Gets the file metadata for a model version via the Civitai API.

    Args:
        version_id: Civitai model-version id (taken from user input).

    Returns:
        The first entry of the response's ``files`` list whose ``name`` ends
        with ``.safetensors``; otherwise the first entry of that list; ``None``
        when the list is empty/missing or the request or parsing fails.
        Failures are printed, never raised.
    """
    from urllib.parse import quote  # function-scope import: only needed here

    # The id comes from a free-text box, so escape it before it becomes a path segment.
    api_url = f"https://civitai.com/api/v1/model-versions/{quote(str(version_id).strip(), safe='')}"
    try:
        # A timeout keeps a stalled API call from blocking the request forever.
        response = requests.get(api_url, timeout=30)
        response.raise_for_status()
        files = response.json().get('files') or []
        for file_data in files:
            # `.get` so an entry without a name does not abort the whole lookup.
            if str(file_data.get('name', '')).endswith('.safetensors'):
                return file_data
        return files[0] if files else None
    except Exception as e:
        # Best-effort lookup: callers treat None as "could not get file info".
        print(f"Could not get file info from Civitai API: {e}")
        return None
def download_file(url, save_path, api_key=None, progress=None, desc=""):
    """Downloads a file, skipping if it already exists.

    The payload is streamed into a temporary file next to ``save_path`` and
    moved into place only after the transfer completes, so an interrupted
    download never leaves a truncated file that later calls would treat as
    "already exists".

    Args:
        url: Address to download from.
        save_path: Destination path of the finished file.
        api_key: Optional token sent as a ``Bearer`` Authorization header.
        progress: Optional callable invoked as ``progress(fraction, desc=desc)``.
        desc: Label forwarded to ``progress``.

    Returns:
        A human-readable status string (already exists / success / failure).
        Errors are reported through the string rather than raised.
    """
    import tempfile  # function-scope import: only this helper needs it

    file_name = os.path.basename(save_path)
    if os.path.exists(save_path):
        return f"File already exists: {file_name}"

    headers = {}
    if api_key and api_key.strip():
        # Strip so a key pasted with surrounding whitespace still authenticates.
        headers['Authorization'] = f'Bearer {api_key.strip()}'

    temp_path = None
    try:
        if progress:
            progress(0, desc=desc)
        # (connect, read) timeouts; the context manager releases the connection.
        with requests.get(url, stream=True, headers=headers, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with tempfile.NamedTemporaryFile(
                "wb",
                dir=os.path.dirname(os.path.abspath(save_path)),
                suffix=".part",
                delete=False,
            ) as f:
                temp_path = f.name
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    if progress and total_size > 0:
                        downloaded += len(chunk)
                        progress(min(downloaded / total_size, 1.0), desc=desc)
        # Same directory, so the rename cannot cross filesystems.
        os.replace(temp_path, save_path)
        temp_path = None
        return f"Successfully downloaded: {file_name}"
    except Exception as e:
        # Deliberately broad: callers only expect a status string back.
        return f"Download failed for {file_name}: {e}"
    finally:
        # Remove the partial file if the transfer did not complete.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
def process_long_prompt(compel_proc, prompt, negative_prompt=""):
    """Encode the prompt pair with `compel_proc`, falling back to `(None, None)`.

    Args:
        compel_proc: Callable invoked with ``[prompt, negative_prompt]`` that
            returns a ``(conditioning, pooled)`` pair.
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.

    Returns:
        ``(conditioning, pooled)`` on success, or ``(None, None)`` when the
        encoder raises; callers use the ``None`` result to fall back to plain
        text prompts.
    """
    try:
        conditioning, pooled = compel_proc([prompt, negative_prompt])
        return conditioning, pooled
    except Exception as e:
        # Still best-effort, but no longer silent: log why the fallback was taken.
        print(f"Long-prompt processing failed, falling back to plain prompts: {e}")
        return None, None
def pre_download_base_model(model_name, progress=gr.Progress(track_tqdm=True)):
    """Fetch a base model ahead of generation by loading it once and discarding it.

    Args:
        model_name: Repo id of the model; a falsy value returns a hint instead.
        progress: Gradio progress tracker.

    Returns:
        A status message describing success or failure.
    """
    if not model_name:
        return "Please select a base model to download."

    outcome = []
    load_options = {"torch_dtype": torch.float16, "use_safetensors": True}
    try:
        progress(0, desc=f"Starting download for: {model_name}")
        if model_name not in SINGLE_FILE_MODELS:
            print(f"Pre-downloading diffusers model: {model_name}")
            pipe = StableDiffusionXLPipeline.from_pretrained(model_name, **load_options)
        else:
            # Single-file repos: fetch the named checkpoint, then load it.
            filename = SINGLE_FILE_MODELS[model_name]
            print(f"Pre-downloading single file: {filename} from repo: {model_name}")
            checkpoint_path = hf_hub_download(repo_id=model_name, filename=filename)
            pipe = StableDiffusionXLPipeline.from_single_file(checkpoint_path, **load_options)
        outcome.append(f"✅ Successfully downloaded {model_name}")
        # The pipeline object itself is not needed here; drop it right away.
        del pipe
    except Exception as e:
        outcome.append(f"❌ Failed to download {model_name}: {e}")
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return "\n".join(outcome)
def pre_download_loras(civitai_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Download every LoRA whose Civitai version id was filled in.

    Args:
        civitai_api_key: Optional API key forwarded to `download_file`.
        *lora_data: Flat sequence alternating (version id, weight); only the
            ids (even positions) are used.
        progress: Gradio progress tracker.

    Returns:
        One status line per requested id, joined by newlines, or a hint when
        no id was provided.
    """
    requested_ids = [raw for raw in lora_data[0::2] if raw and raw.strip()]
    if not requested_ids:
        return "No LoRA IDs provided to download."

    total = len(requested_ids)
    report = []
    for position, raw_id in enumerate(requested_ids):
        version_id = raw_id.strip()
        progress(position / total, desc=f"Getting URL for LoRA ID: {version_id}")

        info = get_civitai_file_info(version_id)
        if not info:
            report.append(f"* LoRA ID {version_id}: Could not get file info from Civitai.")
            continue

        url = info.get('downloadUrl')
        if not url:
            report.append(f"* LoRA ID {version_id}: Could not get download link.")
            continue

        target_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
        outcome = download_file(
            url,
            target_path,
            api_key=civitai_api_key,
            progress=progress,
            desc=f"Downloading LoRA ID: {version_id}",
        )
        report.append(f"* LoRA ID {version_id}: {outcome}")
    return "\n".join(report)
def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
                 sampler, schedule_type,
                 civitai_api_key,
                 *lora_data,
                 progress=gr.Progress(track_tqdm=True)):
    """Load the selected SDXL pipeline, apply LoRAs and generate `batch_size` images.

    Args:
        base_model_name: Repo id of the base model.
        prompt / negative_prompt: Prompt texts.
        seed: Seed for the first image; ``-1`` (or an empty field) means random.
            Every image after the first always gets a fresh random seed.
        batch_size: Number of images to generate, one pipeline call each.
        width / height: Output size in pixels.
        guidance_scale: CFG scale.
        num_inference_steps: Number of sampling steps.
        sampler: Key of `SAMPLER_MAP`; unknown keys use Euler Ancestral.
        schedule_type: One of the `SCHEDULE_TYPE_MAP` labels.
        civitai_api_key: Optional key used when a LoRA has to be downloaded.
        *lora_data: Flat sequence alternating (Civitai version id, weight).
        progress: Gradio progress tracker.

    Returns:
        List of generated images; each carries its generation settings under
        ``image.info['parameters']``.

    Raises:
        gr.Error: wraps any failure during loading or generation.
    """
    pipe = None
    try:
        progress(0, desc=f"Loading model: {base_model_name}")
        if base_model_name in SINGLE_FILE_MODELS:
            filename = SINGLE_FILE_MODELS[base_model_name]
            print(f"Loading single file: {filename} from repo: {base_model_name}")
            local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
            pipe = StableDiffusionXLPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        else:
            print(f"Loading diffusers model: {base_model_name}")
            pipe = StableDiffusionXLPipeline.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        pipe.to(device)

        # UI widgets may deliver floats (or None for an emptied number box);
        # normalise once so the pipeline and the metadata see integers.
        batch_size = int(batch_size)
        seed = -1 if seed is None else int(seed)
        width, height = int(width), int(height)
        num_inference_steps = int(num_inference_steps)

        pipe.unload_lora_weights()

        scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
        # Pass overrides to `from_config` instead of mutating the pipeline's config.
        scheduler_kwargs = {
            'prediction_type': 'v_prediction' if base_model_name in V_PREDICTION_MODELS else 'epsilon',
        }
        if schedule_type == "Karras" or (schedule_type == "Default" and sampler == "DPM++ 2M Karras"):
            scheduler_kwargs['use_karras_sigmas'] = True
        elif schedule_type == "Uniform":
            scheduler_kwargs['use_karras_sigmas'] = False
        elif schedule_type == "SGM Uniform":
            # 'sgm_uniform' is not a valid `algorithm_type`; diffusers maps the
            # sgm_uniform schedule to trailing timestep spacing.
            scheduler_kwargs['timestep_spacing'] = 'trailing'
        pipe.scheduler = scheduler_class.from_config(pipe.scheduler.config, **scheduler_kwargs)

        compel_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
        compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
                        returned_embeddings_type=compel_type, requires_pooled=[False, True], truncate_long_prompts=False)

        civitai_ids, lora_scales = lora_data[0::2], lora_data[1::2]
        active_loras, active_lora_names_for_meta = [], []
        for i, (civitai_id, lora_scale) in enumerate(zip(civitai_ids, lora_scales)):
            # A slot is active only with a non-blank id and a positive weight.
            if not (civitai_id and civitai_id.strip()) or lora_scale is None or lora_scale <= 0:
                continue
            version_id = civitai_id.strip()
            local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
            if not os.path.exists(local_lora_path):
                file_info = get_civitai_file_info(version_id)
                if not file_info:
                    print(f"Could not get file info for Civitai ID {version_id}, skipping.")
                    continue
                download_url = file_info.get('downloadUrl')
                if not download_url:
                    print(f"Could not get download link for Civitai ID {version_id} during inference, skipping.")
                    continue
                download_file(download_url, local_lora_path, api_key=civitai_api_key, progress=progress, desc=f"Downloading LoRA ID {version_id}")
            # The download reports failure via its return string; check the file itself.
            if not os.path.exists(local_lora_path):
                print(f"LoRA file for ID {version_id} not found, skipping.")
                continue
            adapter_name = f"lora_{i+1}"
            progress((i * 0.1) + 0.05, desc=f"Loading LoRA (ID: {version_id})")
            pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
            active_loras.append((adapter_name, lora_scale))
            active_lora_names_for_meta.append(f"LoRA {i+1} (ID: {version_id}, Weight: {lora_scale})")

        if active_loras:
            adapter_names, adapter_weights = zip(*active_loras)
            pipe.set_adapters(list(adapter_names), list(adapter_weights))

        # (None, None) means the embeddings could not be built; plain prompts are used then.
        conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)

        pipe_args = {
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "width": width,
            "height": height,
        }

        output_images = []
        loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""

        for i in range(batch_size):
            progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")

            if i == 0 and seed != -1:
                current_seed = seed
            else:
                current_seed = random.randint(0, MAX_SEED)

            pipe_args["generator"] = torch.Generator(device=device).manual_seed(current_seed)

            if conditioning is not None:
                # Row 0 holds the positive prompt, row 1 the negative prompt.
                image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
            else:
                image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]

            # Join only the non-empty parts so the line does not end with a
            # dangling comma when no LoRA was used.
            settings = [
                f"Steps: {num_inference_steps}",
                f"Sampler: {sampler}",
                f"Schedule type: {schedule_type}",
                f"CFG scale: {guidance_scale}",
                f"Seed: {current_seed}",
                f"Size: {width}x{height}",
                f"Base Model: {base_model_name}",
            ]
            if loras_string:
                settings.append(loras_string)
            params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n" + ", ".join(settings)

            image.info = {'parameters': params_string}
            output_images.append(image)

        return output_images
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        raise gr.Error(f"Generation failed: {e}") from e
    finally:
        if pipe is not None:
            try:
                pipe.disable_lora()
            except Exception as cleanup_error:
                # A cleanup failure must not replace the error (or result) above.
                print(f"Ignoring error while disabling LoRA during cleanup: {cleanup_error}")
            del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def infer(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
          sampler, schedule_type,
          civitai_api_key,
          zero_gpu_duration,
          *lora_data,
          progress=gr.Progress(track_tqdm=True)):
    """Run `_infer_logic` under a `spaces.GPU` wrapper with a per-request duration.

    Args:
        zero_gpu_duration: Requested GPU time in seconds. Empty, zero or
            negative values use the 60 s default; larger values are capped at
            the 120 s maximum advertised next to the input field.
        (all other arguments are forwarded unchanged to `_infer_logic`)

    Returns:
        Whatever `_infer_logic` returns (the list of generated images).
    """
    default_duration, max_duration = 60, 120
    duration = default_duration
    if zero_gpu_duration and int(zero_gpu_duration) > 0:
        # Enforce the maximum promised by the UI instead of trusting the input.
        duration = min(int(zero_gpu_duration), max_duration)
    print(f"Using ZeroGPU duration: {duration} seconds")
    # The wrapper is built per call because the duration is only known per request.
    decorated_infer_logic = spaces.GPU(duration=duration)(_infer_logic)
    return decorated_infer_logic(
        base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
        sampler, schedule_type, civitai_api_key, *lora_data, progress=progress
    )
def _parse_parameters(params_text):
    """Parse a 'parameters' metadata string into a dict of generation settings.

    Expected layout: prompt line(s), an optional block starting with
    ``Negative prompt:``, then a settings line starting with ``Steps:``.
    Prompts spanning several lines and empty prompts are supported. When no
    line starts with ``Steps:``, the third line (if any) is used as the
    settings line.

    Args:
        params_text: Raw metadata text.

    Returns:
        Dict with keys ``prompt``, ``negative_prompt``, ``steps``, ``sampler``,
        ``schedule_type``, ``cfg_scale``, ``seed``, ``base_model``,
        ``model_hash``, ``width``, ``height``, ``lora_ids`` and
        ``lora_scales`` (both lists of length ``MAX_LORAS``). Missing or
        malformed settings take their defaults.
    """
    data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
    lines = params_text.strip().split('\n')

    # Locate the settings line from the end, so a prompt line that happens to
    # start with "Steps:" is not mistaken for it.
    settings_index = None
    for idx in range(len(lines) - 1, -1, -1):
        if lines[idx].startswith("Steps:"):
            settings_index = idx
            break
    if settings_index is not None:
        params_line = lines[settings_index]
        text_lines = lines[:settings_index]
    else:
        # Fallback: fixed three-line layout.
        params_line = lines[2] if len(lines) > 2 else ""
        text_lines = lines[:2]

    negative_marker = "Negative prompt:"
    negative_index = next(
        (idx for idx, line in enumerate(text_lines) if line.startswith(negative_marker)),
        None,
    )
    if negative_index is None:
        data['prompt'] = "\n".join(text_lines)
        data['negative_prompt'] = ""
    else:
        data['prompt'] = "\n".join(text_lines[:negative_index])
        negative_lines = [text_lines[negative_index][len(negative_marker):]] + text_lines[negative_index + 1:]
        data['negative_prompt'] = "\n".join(negative_lines).strip()

    def find_param(key, default, cast_type=str):
        # Value runs up to the next comma or the end of the settings line.
        match = re.search(fr"\b{re.escape(key)}: ([^,]+?)(,|$)", params_line)
        if match:
            try:
                return cast_type(match.group(1).strip())
            except (ValueError, TypeError):
                return default
        return default

    data['steps'] = find_param("Steps", 28, int)
    data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER)
    data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE)
    data['cfg_scale'] = find_param("CFG scale", 7.0, float)
    data['seed'] = find_param("Seed", -1, int)
    data['base_model'] = find_param("Base Model", MODEL_LIST[0])
    data['model_hash'] = find_param("Model hash", None)

    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    if size_match:
        data['width'], data['height'] = int(size_match.group(1)), int(size_match.group(2))
    else:
        data['width'], data['height'] = 1024, 1024

    if loras_match := re.search(r"LoRAs: \[(.+?)\]", params_line):
        lora_pairs = re.findall(r"ID: (\d+), Weight: ([\d.]+)", loras_match.group(1))
        for i, (lora_id, lora_scale) in enumerate(lora_pairs[:MAX_LORAS]):
            try:
                scale = float(lora_scale)
            except ValueError:
                # e.g. "1.2.3" matches the pattern but is not a number; skip it.
                continue
            data['lora_ids'][i] = lora_id
            data['lora_scales'][i] = scale
    return data
def get_png_info(image):
    """Extract prompt, negative prompt and remaining settings from an image.

    Args:
        image: Object exposing an ``info`` mapping (may be ``None``).

    Returns:
        Tuple ``(prompt, negative_prompt, other_parameters)``. The third item
        lists one setting per line, or carries a hint / error message when
        nothing could be extracted.
    """
    if image is None:
        return "", "", "Please upload an image first."
    params = image.info.get('parameters', None)
    if not params:
        return "", "", "No metadata found in the image."
    try:
        parsed_data = _parse_parameters(params)
        lines = params.strip().split('\n')
        # Prompts may span several lines, so look for the settings line by its
        # "Steps:" prefix and only fall back to the fixed third line.
        settings_lines = [line for line in lines if line.startswith("Steps:")]
        if settings_lines:
            other_params_text = settings_lines[-1]
        else:
            other_params_text = lines[2] if len(lines) > 2 else ""
        # Split on commas outside square brackets so the bracketed LoRA list
        # stays on a single line; drop empty fragments from trailing commas.
        fragments = re.split(r",(?![^\[\]]*\])", other_params_text)
        other_params_display = "\n".join(p.strip() for p in fragments if p.strip())
        return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_display
    except Exception as e:
        return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
def send_info_to_txt2img(image):
    """Translate an image's stored parameters into updates for the txt2img tab.

    Args:
        image: Object exposing an ``info`` mapping (may be ``None``).

    Returns:
        List of ``12 + MAX_LORAS * 2 + 1`` values in the order: base model,
        prompt, negative prompt, seed, two untouched slots (``gr.update()``),
        width, height, CFG scale, steps, sampler, schedule type, then
        (LoRA id, LoRA weight) per slot, and finally a tab switch to tab 0.
        Without usable metadata every entry is a no-op ``gr.update()``.
    """
    if image is None or not (params := image.info.get('parameters', '')):
        return [gr.update()] * (12 + MAX_LORAS * 2 + 1)

    data = _parse_parameters(params)

    # A recognised model hash takes precedence over the "Base Model" entry.
    model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
    final_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
    # Metadata may name a model outside the known list; validate it the same
    # way sampler and schedule type are validated below.
    if final_base_model not in MODEL_LIST:
        final_base_model = MODEL_LIST[0]

    sampler_from_png = data.get('sampler', DEFAULT_SAMPLER)
    final_sampler = sampler_from_png if sampler_from_png in SAMPLER_MAP else DEFAULT_SAMPLER

    schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
    final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE

    updates = [final_base_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
               data['cfg_scale'], data['steps'], final_sampler, final_schedule_type]
    for lora_id, lora_scale in zip(data['lora_ids'][:MAX_LORAS], data['lora_scales'][:MAX_LORAS]):
        updates.extend([lora_id, lora_scale])
    updates.append(gr.Tabs(selected=0))
    return updates
# --- Gradio UI: a "txt2img" tab for generation and a "PNG Info" tab that reads
# the generation parameters back from an uploaded image. ---
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
    gr.Markdown("# Animated SDXL T2I with LoRAs")
    with gr.Tabs(elem_id="tabs_container") as tabs:
        with gr.TabItem("txt2img", id=0):
            gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading the base model and LoRAs before clicking 'Run' can maximize your ZeroGPU time.</div>")
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    with gr.Column(scale=3):
                        base_model_name = gr.Dropdown(label="Base Model", choices=MODEL_LIST, value="Laxhar/noobai-XL-Vpred-1.0")
                    with gr.Column(scale=2):
                        with gr.Row():
                            predownload_base_model_button = gr.Button("Pre-download Base Model")
                            predownload_lora_button = gr.Button("Pre-download LoRAs")
                    with gr.Column(scale=1, min_width=100):
                        run_button = gr.Button("Run", variant="primary")
                predownload_status = gr.Markdown("")
                prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
                negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
                # --- UI Layout ---
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        with gr.Row():
                            sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
                            schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
                        with gr.Row():
                            guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                            num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
                    with gr.Column(scale=1):
                        result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
                with gr.Row():
                    seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
                    # Bounds match the limits stated in `info`; an empty field
                    # still means "use the default".
                    zero_gpu_duration = gr.Number(
                        label="ZeroGPU Duration (s)",
                        value=None,
                        minimum=0,
                        maximum=120,
                        precision=0,
                        placeholder="Default: 60s",
                        info="Optional: Leave empty for default (60s), max to 120"
                    )
                with gr.Accordion("LoRA Settings", open=False):
                    gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
                    civitai_api_key = gr.Textbox(label="Optional Civitai API Key", info="Get from your Civitai account settings...", placeholder="Enter your Civitai API Key here", type="password", show_label=True)
                    gr.Markdown("Find the Model Version ID in the LoRA page URL (e.g., `modelVersionId=12345`) and fill it in below.")
                    lora_rows, lora_civitai_id_inputs, lora_scale_inputs = [], [], []
                    # All MAX_LORAS rows are created up front; only the first starts visible.
                    for i in range(MAX_LORAS):
                        with gr.Row(visible=(i == 0)) as row:
                            lora_civitai_id = gr.Textbox(label=f"LoRA {i+1} - Civitai Model Version ID", placeholder="e.g.: 1834914")
                            lora_scale = gr.Slider(label=f"Weight {i+1}", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
                        lora_rows.append(row)
                        lora_civitai_id_inputs.append(lora_civitai_id)
                        lora_scale_inputs.append(lora_scale)
                    with gr.Row():
                        add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
                    # Number of LoRA rows currently visible.
                    lora_count_state = gr.State(value=1)
                    # Flattened as id1, scale1, id2, scale2, ... — the order the handlers expect.
                    all_lora_inputs = [item for pair in zip(lora_civitai_id_inputs, lora_scale_inputs) for item in pair]
        with gr.TabItem("PNG Info", id=1):
            with gr.Column(elem_id="col-container"):
                gr.Markdown("Upload a generated image to view its generation data.")
                info_image_input = gr.Image(type="pil", label="Upload Image")
                with gr.Row():
                    info_get_button = gr.Button("Get Info", variant="secondary")
                    send_to_txt2img_button = gr.Button("Send to Txt-to-Image", variant="primary")
                gr.Markdown("### Positive Prompt")
                info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Negative Prompt")
                info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Other Parameters")
                info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
    gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by RioShiina with ❤</div>")

    def add_lora_row(current_count):
        """Reveal the next hidden LoRA row; hide the add button once all rows show."""
        current_count = int(current_count)
        if current_count < MAX_LORAS:
            updates = {lora_count_state: current_count + 1, lora_rows[current_count]: gr.Row(visible=True)}
            if current_count + 1 == MAX_LORAS:
                updates[add_lora_button] = gr.Button(visible=False)
            return updates
        return {lora_count_state: current_count}

    def reveal_populated_lora_rows(current_count, *lora_ids):
        """Show every LoRA row up to the last one that holds an id.

        Values restored from image metadata can land in rows that are still
        hidden; those LoRAs would be applied without the user seeing them.
        """
        current_count = int(current_count)
        populated = [idx for idx, lora_id in enumerate(lora_ids) if lora_id and str(lora_id).strip()]
        needed = populated[-1] + 1 if populated else 1
        new_count = min(MAX_LORAS, max(current_count, needed))
        return [new_count, gr.Button(visible=new_count < MAX_LORAS)] + [
            gr.Row(visible=idx < new_count) for idx in range(MAX_LORAS)
        ]

    add_lora_button.click(fn=add_lora_row, inputs=[lora_count_state], outputs=[lora_count_state, add_lora_button] + lora_rows)
    predownload_base_model_button.click(fn=pre_download_base_model, inputs=[base_model_name], outputs=[predownload_status])
    predownload_lora_button.click(fn=pre_download_loras, inputs=[civitai_api_key, *all_lora_inputs], outputs=[predownload_status])
    run_button.click(fn=infer,
                     inputs=[base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, zero_gpu_duration, *all_lora_inputs],
                     outputs=[result])
    info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
    # Order must match the list returned by `send_info_to_txt2img`.
    txt2img_outputs = [base_model_name, prompt, negative_prompt, seed, batch_size, zero_gpu_duration, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, *all_lora_inputs, tabs]
    send_to_txt2img_button.click(
        fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs
    ).then(
        # Runs after the values are restored, so the new LoRA ids are visible to it.
        fn=reveal_populated_lora_rows,
        inputs=[lora_count_state, *lora_civitai_id_inputs],
        outputs=[lora_count_state, add_lora_button] + lora_rows,
    )
demo.queue().launch()