# ImageGen-NoobAI / app.py
# Hugging Face Space file (uploaded by RioShiina, revision 992e30a, 25.6 kB).
import spaces
import gradio as gr
import numpy as np
import PIL.Image
from PIL import Image, PngImagePlugin
import random
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler, UniPCMultistepScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler
import torch
from compel import Compel, ReturnedEmbeddingsType
import requests
import os
import re
import gc
from huggingface_hub import hf_hub_download
# This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
@spaces.GPU(duration=60)
def dummy_gpu_for_startup():
    """No-op GPU-decorated function; satisfies the Spaces ZeroGPU startup probe."""
    print("Dummy function for startup check executed. This is normal.")
    return "Startup check passed."
# --- Constants ---
MAX_LORAS = 5  # maximum number of LoRA input rows shown in the UI
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_SEED = np.iinfo(np.int64).max  # upper bound for randomly drawn seeds
MAX_IMAGE_SIZE = 1216  # max width/height exposed by the size sliders
# Mapping from UI sampler names to diffusers scheduler classes.
SAMPLER_MAP = {
    "Euler a": EulerAncestralDiscreteScheduler,
    "Euler": EulerDiscreteScheduler,
    "DPM++ 2M Karras": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "UniPC": UniPCMultistepScheduler,
    "Heun": HeunDiscreteScheduler,
    "LMS": LMSDiscreteScheduler,
}
# Noise-schedule variants selectable in the UI; interpreted in _infer_logic.
SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]
DEFAULT_SCHEDULE_TYPE = "Default"
DEFAULT_SAMPLER = "Euler a"
DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
DOWNLOAD_DIR = "/tmp/loras"  # local cache directory for downloaded LoRA files
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
# --- Model Lists ---
# Base SDXL checkpoints selectable in the UI (Hugging Face repo IDs).
MODEL_LIST = [
    "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    "Laxhar/noobai-XL-Vpred-1.0",
    "John6666/hassaku-xl-illustrious-v30-sdxl",
    "RedRayz/hikari_noob_v-pred_1.2.2",
    "bluepen5805/noob_v_pencil-XL",
    "Laxhar/noobai-XL-1.1"
]
# --- List of V-Prediction Models ---
# These checkpoints need prediction_type='v_prediction' set on the scheduler.
V_PREDICTION_MODELS = [
    "Laxhar/noobai-XL-Vpred-1.0",
    "RedRayz/hikari_noob_v-pred_1.2.2",
    "bluepen5805/noob_v_pencil-XL"
]
# --- Dictionary for single-file models now stores the filename ---
# Repos distributed as one .safetensors file instead of the diffusers layout.
SINGLE_FILE_MODELS = {
    "bluepen5805/noob_v_pencil-XL": "noob_v_pencil-XL-v3.0.0.safetensors"
}
# --- Model Hash to Name Mapping ---
# Maps the "Model hash" field found in A1111-style PNG metadata back to a
# repo ID so "Send to Txt-to-Image" can restore the correct base model.
HASH_TO_MODEL_MAP = {
    "bdb59bac77": "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
    "ea349eeae8": "Laxhar/noobai-XL-Vpred-1.0",
    "b4fb5f829a": "John6666/hassaku-xl-illustrious-v30-sdxl",
    "6681e8e4b1": "Laxhar/noobai-XL-1.1",
    "90b7911a78": "bluepen5805/noob_v_pencil-XL",
    "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
}
def get_civitai_file_info(version_id):
    """Return file metadata for a Civitai model version, preferring .safetensors.

    Args:
        version_id: Civitai model-version ID (string or int).

    Returns:
        The metadata dict of the first ``.safetensors`` file, the first file
        of any type as a fallback, or ``None`` on error / no files.
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        # Fixed: a timeout prevents the whole app from hanging forever if
        # the Civitai API is unresponsive.
        response = requests.get(api_url, timeout=30)
        response.raise_for_status()
        data = response.json()
        for file_data in data.get('files', []):
            if file_data['name'].endswith('.safetensors'):
                return file_data
        if data.get('files'):
            return data['files'][0]
        return None
    except Exception as e:
        # Best-effort lookup: callers treat None as "skip this LoRA".
        print(f"Could not get file info from Civitai API: {e}")
        return None
def download_file(url, save_path, api_key=None, progress=None, desc=""):
    """Stream a file from `url` to `save_path`, skipping if it already exists.

    Args:
        url: Direct download URL.
        save_path: Destination path on disk.
        api_key: Optional Civitai API key, sent as a Bearer token.
        progress: Optional gr.Progress-style callable for UI feedback.
        desc: Label shown on the progress bar.

    Returns:
        A human-readable status string (already exists / success / failure).
    """
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    headers = {}
    if api_key and api_key.strip():
        headers['Authorization'] = f'Bearer {api_key}'
    try:
        if progress: progress(0, desc=desc)
        # Fixed: added a timeout so a hung connection cannot stall the app.
        # Streaming avoids loading multi-GB checkpoints into memory.
        response = requests.get(url, stream=True, headers=headers, timeout=30)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(save_path, "wb") as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                if progress and total_size > 0:
                    downloaded += len(chunk)
                    progress(downloaded / total_size, desc=desc)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        # Remove any partial file so a retry starts from a clean state.
        if os.path.exists(save_path): os.remove(save_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"
def process_long_prompt(compel_proc, prompt, negative_prompt=""):
    """Encode the prompt pair with compel; return (None, None) on any failure
    so the caller can fall back to the pipeline's built-in text encoding."""
    try:
        embeddings, pooled_embeddings = compel_proc([prompt, negative_prompt])
    except Exception:
        return None, None
    return embeddings, pooled_embeddings
def pre_download_base_model(model_name, progress=gr.Progress(track_tqdm=True)):
    """Warm the local HF cache by downloading a base model's weights on CPU.

    Loading (then discarding) the pipeline ahead of time means the later GPU
    run does not spend its ZeroGPU budget on downloads.

    Returns a human-readable status string for the UI.
    """
    if not model_name:
        return "Please select a base model to download."
    status_log = []
    try:
        progress(0, desc=f"Starting download for: {model_name}")
        if model_name in SINGLE_FILE_MODELS:
            filename = SINGLE_FILE_MODELS[model_name]
            # Fixed: this log line previously printed a literal "(unknown)"
            # instead of the actual checkpoint filename.
            print(f"Pre-downloading single file: {filename} from repo: {model_name}")
            local_path = hf_hub_download(repo_id=model_name, filename=filename)
            pipe = StableDiffusionXLPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        else:
            print(f"Pre-downloading diffusers model: {model_name}")
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        status_log.append(f"✅ Successfully downloaded {model_name}")
        # The pipeline object was only needed to trigger the download.
        del pipe
    except Exception as e:
        status_log.append(f"❌ Failed to download {model_name}: {e}")
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return "\n".join(status_log)
def pre_download_loras(civitai_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
    """Fetch every requested Civitai LoRA ahead of time so GPU seconds are
    not spent downloading during inference. Returns a status report string.

    `lora_data` is the flat interleaved (id, scale, id, scale, ...) sequence
    coming from the UI; only the IDs (even positions) matter here.
    """
    requested_ids = [cid.strip() for cid in lora_data[0::2] if cid and cid.strip()]
    if not requested_ids:
        return "No LoRA IDs provided to download."
    report = []
    total = len(requested_ids)
    for index, version_id in enumerate(requested_ids):
        progress(index / total, desc=f"Getting URL for LoRA ID: {version_id}")
        destination = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
        file_info = get_civitai_file_info(version_id)
        if not file_info:
            report.append(f"* LoRA ID {version_id}: Could not get file info from Civitai.")
            continue
        download_url = file_info.get('downloadUrl')
        if not download_url:
            report.append(f"* LoRA ID {version_id}: Could not get download link.")
            continue
        result = download_file(
            download_url,
            destination,
            api_key=civitai_api_key,
            progress=progress,
            desc=f"Downloading LoRA ID: {version_id}"
        )
        report.append(f"* LoRA ID {version_id}: {result}")
    return "\n".join(report)
def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
                 sampler, schedule_type,
                 civitai_api_key,
                 *lora_data,
                 progress=gr.Progress(track_tqdm=True)):
    """Core SDXL text-to-image generation; returns a list of PIL images.

    Loads the selected base model (diffusers repo or single .safetensors
    file), configures the scheduler (including v-prediction for models that
    need it), applies any requested Civitai LoRAs, encodes prompts with
    compel, and generates `batch_size` images.

    `lora_data` is a flat interleaved sequence: (id_1, scale_1, id_2, ...).
    Raises gr.Error on failure so the UI shows a message; always releases
    the pipeline and GPU memory in the finally block.
    """
    pipe = None
    try:
        progress(0, desc=f"Loading model: {base_model_name}")
        if base_model_name in SINGLE_FILE_MODELS:
            filename = SINGLE_FILE_MODELS[base_model_name]
            # Fixed: log the real checkpoint filename (was a literal "(unknown)").
            print(f"Loading single file: {filename} from repo: {base_model_name}")
            local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
            pipe = StableDiffusionXLPipeline.from_single_file(
                local_path,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        else:
            print(f"Loading diffusers model: {base_model_name}")
            pipe = StableDiffusionXLPipeline.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                use_safetensors=True
            )
        pipe.to(device)
        batch_size = int(batch_size)
        seed = int(seed)
        pipe.unload_lora_weights()  # start from a clean adapter state
        scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
        # Fixed: copy the config to a plain dict — diffusers scheduler configs
        # are FrozenDicts in recent releases and reject in-place mutation.
        scheduler_config = dict(pipe.scheduler.config)
        if base_model_name in V_PREDICTION_MODELS:
            scheduler_config['prediction_type'] = 'v_prediction'
        else:
            scheduler_config['prediction_type'] = 'epsilon'
        scheduler_kwargs = {}
        if schedule_type == "Default" and sampler == "DPM++ 2M Karras":
            # "DPM++ 2M Karras" implies Karras sigmas even on the default schedule.
            scheduler_kwargs['use_karras_sigmas'] = True
        elif schedule_type == "Karras":
            scheduler_kwargs['use_karras_sigmas'] = True
        elif schedule_type == "Uniform":
            scheduler_kwargs['use_karras_sigmas'] = False
        elif schedule_type == "SGM Uniform":
            scheduler_kwargs['algorithm_type'] = 'sgm_uniform'
        pipe.scheduler = scheduler_class.from_config(scheduler_config, **scheduler_kwargs)
        compel_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
        compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
                        returned_embeddings_type=compel_type, requires_pooled=[False, True], truncate_long_prompts=False)
        civitai_ids, lora_scales = lora_data[0::2], lora_data[1::2]
        lora_params = list(zip(civitai_ids, lora_scales))
        active_loras, active_lora_names_for_meta = [], []
        for i, (civitai_id, lora_scale) in enumerate(lora_params):
            if civitai_id and civitai_id.strip() and lora_scale > 0:
                version_id = civitai_id.strip()
                local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
                if not os.path.exists(local_lora_path):
                    # Not pre-downloaded: fetch it now (costs GPU time).
                    file_info = get_civitai_file_info(version_id)
                    if not file_info:
                        print(f"Could not get file info for Civitai ID {version_id}, skipping.")
                        continue
                    download_url = file_info.get('downloadUrl')
                    if download_url:
                        download_file(download_url, local_lora_path, api_key=civitai_api_key, progress=progress, desc=f"Downloading LoRA ID {version_id}")
                    else:
                        print(f"Could not get download link for Civitai ID {version_id} during inference, skipping."); continue
                if not os.path.exists(local_lora_path): print(f"LoRA file for ID {version_id} not found, skipping."); continue
                adapter_name = f"lora_{i+1}"
                progress((i * 0.1) + 0.05, desc=f"Loading LoRA (ID: {version_id})")
                pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
                active_loras.append((adapter_name, lora_scale))
                active_lora_names_for_meta.append(f"LoRA {i+1} (ID: {version_id}, Weight: {lora_scale})")
        if active_loras:
            adapter_names, adapter_weights = zip(*active_loras); pipe.set_adapters(list(adapter_names), list(adapter_weights))
        # Compel supports prompts beyond the 77-token CLIP limit; on failure we
        # fall back to the pipeline's built-in text encoding below.
        conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)
        pipe_args = {
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "width": width,
            "height": height,
        }
        output_images = []
        loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""
        for i in range(batch_size):
            progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")
            # Only the first image honors a fixed seed; the rest are randomized.
            if i == 0 and seed != -1:
                current_seed = seed
            else:
                current_seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(device=device).manual_seed(current_seed)
            pipe_args["generator"] = generator
            if conditioning is not None:
                # [prompt, negative] were encoded together; slice them apart.
                image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
            else:
                image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]
            # Embed A1111-style metadata so the PNG Info tab can parse it back.
            params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n"
            params_string += f"Steps: {num_inference_steps}, Sampler: {sampler}, Schedule type: {schedule_type}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {base_model_name}, {loras_string}".strip()
            image.info = {'parameters': params_string}
            output_images.append(image)
        return output_images
    except Exception as e:
        print(f"An error occurred during generation: {e}"); raise gr.Error(f"Generation failed: {e}")
    finally:
        # Always release VRAM, even when generation fails mid-way.
        if pipe is not None:
            pipe.disable_lora()
            del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def infer(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
          sampler, schedule_type,
          civitai_api_key,
          zero_gpu_duration,
          *lora_data,
          progress=gr.Progress(track_tqdm=True)):
    """Entry point for the Run button.

    Wraps _infer_logic in a spaces.GPU decorator whose duration can be
    overridden by the user (empty/zero falls back to 60 seconds).
    NOTE(review): the UI text mentions a 120s maximum, but no cap is
    enforced here — presumably the ZeroGPU runtime enforces it; confirm.
    """
    requested = int(zero_gpu_duration) if zero_gpu_duration else 0
    duration = requested if requested > 0 else 60
    print(f"Using ZeroGPU duration: {duration} seconds")
    # Decorate at call time so the user-supplied duration takes effect.
    gpu_infer = spaces.GPU(duration=duration)(_infer_logic)
    return gpu_infer(
        base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
        sampler, schedule_type, civitai_api_key, *lora_data, progress=progress
    )
def _parse_parameters(params_text):
    """Parse A1111-style PNG 'parameters' metadata into a dict.

    Expected layout:
        line 0: positive prompt
        line 1: "Negative prompt: ..." (may be absent in other tools' output)
        line 2: comma-separated "Key: value" pairs

    Returns a dict with prompt, negative_prompt, steps, sampler,
    schedule_type, cfg_scale, seed, base_model, model_hash, width, height,
    lora_ids and lora_scales (both padded to MAX_LORAS entries).
    """
    data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
    lines = params_text.strip().split('\n')
    data['prompt'] = lines[0]
    if len(lines) > 1 and lines[1].startswith("Negative prompt:"):
        data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip()
        params_line = lines[2] if len(lines) > 2 else ""
    else:
        data['negative_prompt'] = ""
        # Robustness fix: some tools omit the negative-prompt line entirely,
        # in which case the parameter list is already on the second line.
        if len(lines) > 1 and "Steps:" in lines[1]:
            params_line = lines[1]
        else:
            params_line = lines[2] if len(lines) > 2 else ""
    def find_param(key, default, cast_type=str):
        # Match 'Key: value' up to the next comma (or end of line); fall back
        # to the default when absent or when the cast fails.
        match = re.search(fr"\b{key}: ([^,]+?)(,|$)", params_line)
        if match:
            try:
                return cast_type(match.group(1).strip())
            except (ValueError, TypeError):
                return default
        return default
    data['steps'] = find_param("Steps", 28, int)
    data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER)
    data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE)
    data['cfg_scale'] = find_param("CFG scale", 7.0, float)
    data['seed'] = find_param("Seed", -1, int)
    data['base_model'] = find_param("Base Model", MODEL_LIST[0])
    data['model_hash'] = find_param("Model hash", None)
    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
    # LoRA metadata is written by _infer_logic as "LoRAs: [LoRA n (ID: x, Weight: y), ...]".
    if loras_match := re.search(r"LoRAs: \[(.+?)\]", params_line):
        for i, (lora_id, lora_scale) in enumerate(re.findall(r"ID: (\d+), Weight: ([\d.]+)", loras_match.group(1))):
            if i < MAX_LORAS:
                data['lora_ids'][i] = lora_id
                data['lora_scales'][i] = float(lora_scale)
    return data
def get_png_info(image):
    """Extract and pretty-print the generation metadata stored in a PNG.

    Returns (prompt, negative_prompt, other_parameters_text) for the UI.
    """
    if image is None:
        return "", "", "Please upload an image first."
    params = image.info.get('parameters', None)
    if not params:
        return "", "", "No metadata found in the image."
    try:
        parsed_data = _parse_parameters(params)
        metadata_lines = params.strip().split('\n')
        raw_params = metadata_lines[2] if len(metadata_lines) > 2 else ""
        # One "Key: value" pair per line for readability in the UI.
        pretty_params = "\n".join([piece.strip() for piece in raw_params.split(',')])
        return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), pretty_params
    except Exception as e:
        return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
def send_info_to_txt2img(image):
    """Map parsed PNG metadata onto the txt2img tab's input components.

    The returned list must match the order of the `txt2img_outputs` list used
    when wiring the button: [base_model, prompt, negative_prompt, seed,
    batch_size, zero_gpu_duration, width, height, cfg_scale, steps, sampler,
    schedule_type, (lora_id, lora_scale) x MAX_LORAS, tabs].
    """
    if image is None or not (params := image.info.get('parameters', '')):
        # 12 scalar fields + 2 per LoRA row + the Tabs component.
        return [gr.update()] * (12 + MAX_LORAS * 2 + 1)
    data = _parse_parameters(params)
    # Prefer the model hash (stable across renames) over the stored model name.
    model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
    final_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
    # Fall back to defaults for samplers/schedules this app does not support.
    sampler_from_png = data.get('sampler', DEFAULT_SAMPLER)
    final_sampler = sampler_from_png if sampler_from_png in SAMPLER_MAP else DEFAULT_SAMPLER
    schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
    final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE
    # gr.update() placeholders leave batch_size and zero_gpu_duration untouched.
    updates = [final_base_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
               data['cfg_scale'], data['steps'], final_sampler, final_schedule_type]
    for i in range(MAX_LORAS): updates.extend([data['lora_ids'][i], data['lora_scales'][i]])
    updates.append(gr.Tabs(selected=0))  # switch the UI back to the txt2img tab
    return updates
# --- UI definition: two tabs (txt2img generation, PNG metadata viewer) ---
with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
    gr.Markdown("# Animated SDXL T2I with LoRAs")
    with gr.Tabs(elem_id="tabs_container") as tabs:
        # --- Tab 1: text-to-image generation ---
        with gr.TabItem("txt2img", id=0):
            gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading the base model and LoRAs before clicking 'Run' can maximize your ZeroGPU time.</div>")
            with gr.Column(elem_id="col-container"):
                with gr.Row():
                    with gr.Column(scale=3):
                        base_model_name = gr.Dropdown(label="Base Model", choices=MODEL_LIST, value="Laxhar/noobai-XL-Vpred-1.0")
                    with gr.Column(scale=2):
                        with gr.Row():
                            predownload_base_model_button = gr.Button("Pre-download Base Model")
                            predownload_lora_button = gr.Button("Pre-download LoRAs")
                    with gr.Column(scale=1, min_width=100):
                        run_button = gr.Button("Run", variant="primary")
                # Shared status area for both pre-download buttons.
                predownload_status = gr.Markdown("")
                prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
                negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
                # --- UI Layout ---
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Row():
                            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        with gr.Row():
                            sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
                            schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
                        with gr.Row():
                            guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                            num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
                    with gr.Column(scale=1):
                        result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
                with gr.Row():
                    seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
                    zero_gpu_duration = gr.Number(
                        label="ZeroGPU Duration (s)",
                        value=None,
                        placeholder="Default: 60s",
                        info="Optional: Leave empty for default (60s), max to 120"
                    )
                # LoRA rows start hidden except the first; "Add LoRA" reveals more.
                with gr.Accordion("LoRA Settings", open=False):
                    gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
                    civitai_api_key = gr.Textbox(label="Optional Civitai API Key", info="Get from your Civitai account settings...", placeholder="Enter your Civitai API Key here", type="password", show_label=True)
                    gr.Markdown("Find the Model Version ID in the LoRA page URL (e.g., `modelVersionId=12345`) and fill it in below.")
                    lora_rows, lora_civitai_id_inputs, lora_scale_inputs = [], [], []
                    for i in range(MAX_LORAS):
                        with gr.Row(visible=(i == 0)) as row:
                            lora_civitai_id = gr.Textbox(label=f"LoRA {i+1} - Civitai Model Version ID", placeholder="e.g.: 1834914")
                            lora_scale = gr.Slider(label=f"Weight {i+1}", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
                        lora_rows.append(row); lora_civitai_id_inputs.append(lora_civitai_id); lora_scale_inputs.append(lora_scale)
                    with gr.Row():
                        add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
                # Tracks how many LoRA rows are currently visible.
                lora_count_state = gr.State(value=1)
                # Interleave (id, scale) pairs to match *lora_data unpacking in the handlers.
                all_lora_inputs = [item for pair in zip(lora_civitai_id_inputs, lora_scale_inputs) for item in pair]
        # --- Tab 2: PNG metadata viewer ---
        with gr.TabItem("PNG Info", id=1):
            with gr.Column(elem_id="col-container"):
                gr.Markdown("Upload a generated image to view its generation data.")
                info_image_input = gr.Image(type="pil", label="Upload Image")
                with gr.Row():
                    info_get_button = gr.Button("Get Info", variant="secondary")
                    send_to_txt2img_button = gr.Button("Send to Txt-to-Image", variant="primary")
                gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
                gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
    gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by RioShiina with ❤</div>")
def add_lora_row(current_count):
current_count = int(current_count)
if current_count < MAX_LORAS:
updates = {lora_count_state: current_count + 1, lora_rows[current_count]: gr.Row(visible=True)}
if current_count + 1 == MAX_LORAS: updates[add_lora_button] = gr.Button(visible=False)
return updates
return {lora_count_state: current_count}
add_lora_button.click(fn=add_lora_row, inputs=[lora_count_state], outputs=[lora_count_state, add_lora_button] + lora_rows)
predownload_base_model_button.click(fn=pre_download_base_model, inputs=[base_model_name], outputs=[predownload_status])
predownload_lora_button.click(fn=pre_download_loras, inputs=[civitai_api_key, *all_lora_inputs], outputs=[predownload_status])
run_button.click(fn=infer,
inputs=[base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, zero_gpu_duration, *all_lora_inputs],
outputs=[result])
info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
txt2img_outputs = [base_model_name, prompt, negative_prompt, seed, batch_size, zero_gpu_duration, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, *all_lora_inputs, tabs]
send_to_txt2img_button.click(fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs)
demo.queue().launch()