Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py - Mega-Scale Refactor for Celebrity_LoRa_Mix Space
|
| 3 |
+
|
| 4 |
+
Features:
|
| 5 |
+
- Modular imports and dependency management
|
| 6 |
+
- Advanced error handling with user-facing messages
|
| 7 |
+
- Async-ready pipeline integration with fallback sync support
|
| 8 |
+
- Mobile-first responsive layout with concise UX messaging
|
| 9 |
+
- Leverages helpers.py and lora_manager.py for clarity and reuse
|
| 10 |
+
|
| 11 |
+
Author: Helios Automation Alchemist
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import random
|
| 19 |
+
import time
|
| 20 |
+
import asyncio
|
| 21 |
+
from typing import List
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import pandas as pd
|
| 26 |
+
import requests
|
| 27 |
+
|
| 28 |
+
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
|
| 29 |
+
from transformers import CLIPTokenizer, CLIPProcessor, CLIPModel, LongformerTokenizer, LongformerModel
|
| 30 |
+
from PIL import Image
|
| 31 |
+
|
| 32 |
+
# Custom modules
|
| 33 |
+
import helpers
|
| 34 |
+
from lora_manager import LoRAManager
|
| 35 |
+
|
| 36 |
+
# === Config ===
|
| 37 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 38 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
MAX_SEED = 2**32 - 1
|
| 42 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
+
DTYPE = torch.bfloat16 if DEVICE.type == 'cuda' else torch.float32
|
| 44 |
+
|
| 45 |
+
# === Model & tokenizer loading ===
|
| 46 |
+
def load_tokenizers_and_models():
|
| 47 |
+
try:
|
| 48 |
+
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
|
| 49 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
|
| 50 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
|
| 51 |
+
logger.info("CLIP tokenizer & model loaded.")
|
| 52 |
+
|
| 53 |
+
longformer_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
|
| 54 |
+
longformer_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
| 55 |
+
logger.info("Longformer tokenizer & model loaded.")
|
| 56 |
+
return clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error(f"Tokenizer/model load failed: {e}")
|
| 59 |
+
sys.exit(1)
|
| 60 |
+
|
| 61 |
+
clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model = load_tokenizers_and_models()
|
| 62 |
+
|
| 63 |
+
# === Load prompts and LoRAs ===
|
| 64 |
+
def load_prompts_and_loras():
|
| 65 |
+
try:
|
| 66 |
+
prompts = pd.read_csv("prompts.csv", header=None).values.flatten()
|
| 67 |
+
except FileNotFoundError:
|
| 68 |
+
logger.warning("prompts.csv missing, defaulting to empty prompts.")
|
| 69 |
+
prompts = []
|
| 70 |
+
try:
|
| 71 |
+
with open("loras.json", "r") as f:
|
| 72 |
+
loras = json.load(f)
|
| 73 |
+
except FileNotFoundError:
|
| 74 |
+
logger.warning("loras.json missing, defaulting to empty LoRA list.")
|
| 75 |
+
loras = []
|
| 76 |
+
return prompts, loras
|
| 77 |
+
|
| 78 |
+
PROMPT_VALUES, LORA_LIST = load_prompts_and_loras()
|
| 79 |
+
|
| 80 |
+
# === Initialize Diffusion Pipeline with retry and fallback ===
|
| 81 |
+
def initialize_pipeline(base_model="sayakpaul/FLUX.1-merged", max_retries=3):
|
| 82 |
+
for attempt in range(max_retries):
|
| 83 |
+
try:
|
| 84 |
+
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to(DEVICE)
|
| 85 |
+
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
|
| 86 |
+
|
| 87 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=DTYPE, vae=taef1).to(DEVICE)
|
| 88 |
+
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
| 89 |
+
base_model,
|
| 90 |
+
vae=good_vae,
|
| 91 |
+
transformer=pipe.transformer,
|
| 92 |
+
text_encoder=pipe.text_encoder,
|
| 93 |
+
tokenizer=pipe.tokenizer,
|
| 94 |
+
text_encoder_2=pipe.text_encoder_2,
|
| 95 |
+
tokenizer_2=pipe.tokenizer_2,
|
| 96 |
+
torch_dtype=DTYPE
|
| 97 |
+
)
|
| 98 |
+
pipe.flux_pipe_call_that_returns_an_iterable_of_images = helpers.flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
| 99 |
+
|
| 100 |
+
logger.info("Diffusion pipeline loaded successfully.")
|
| 101 |
+
return pipe, pipe_i2i
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.warning(f"Attempt {attempt + 1} failed: {e}")
|
| 104 |
+
time.sleep(5)
|
| 105 |
+
logger.error("Failed to load diffusion pipeline after retries.")
|
| 106 |
+
sys.exit(1)
|
| 107 |
+
|
| 108 |
+
pipe, pipe_i2i = initialize_pipeline()
|
| 109 |
+
|
| 110 |
+
# === LoRA Manager for adapter lifecycle ===
|
| 111 |
+
lora_manager = LoRAManager(LORA_LIST)
|
| 112 |
+
|
| 113 |
+
# === Core business logic ===
|
| 114 |
+
def process_input(text: str, max_length: int=4096):
|
| 115 |
+
if not text or not text.strip():
|
| 116 |
+
raise gr.Error("Prompt cannot be empty.")
|
| 117 |
+
return longformer_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
|
| 118 |
+
|
| 119 |
+
@helpers.async_run_if_possible
|
| 120 |
+
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
|
| 121 |
+
pipe.to(DEVICE)
|
| 122 |
+
generator = torch.Generator(device=DEVICE).manual_seed(seed)
|
| 123 |
+
with helpers.calculate_duration("Generating image"):
|
| 124 |
+
for step_idx, img in enumerate(pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
| 125 |
+
prompt=prompt,
|
| 126 |
+
num_inference_steps=steps,
|
| 127 |
+
guidance_scale=cfg_scale,
|
| 128 |
+
width=width,
|
| 129 |
+
height=height,
|
| 130 |
+
generator=generator,
|
| 131 |
+
joint_attention_kwargs={"scale": 1.0},
|
| 132 |
+
output_type="pil",
|
| 133 |
+
good_vae=pipe.vae,
|
| 134 |
+
)):
|
| 135 |
+
yield img, seed, gr.update(value=f"Step {step_idx + 1}/{steps}", visible=True)
|
| 136 |
+
|
| 137 |
+
@spaces.GPU(duration=75)
|
| 138 |
+
def run_lora(prompt, cfg_scale, steps, selected_loras_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4,
|
| 139 |
+
randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
|
| 140 |
+
if not selected_loras_indices:
|
| 141 |
+
raise gr.Error("Select at least one LoRA.")
|
| 142 |
+
|
| 143 |
+
selected_loras = [loras_state[i] for i in selected_loras_indices]
|
| 144 |
+
|
| 145 |
+
# Compose prompt with LoRA trigger words
|
| 146 |
+
prepend_words = []
|
| 147 |
+
append_words = []
|
| 148 |
+
for lora in selected_loras:
|
| 149 |
+
tw = lora.get("trigger_word", "")
|
| 150 |
+
if tw:
|
| 151 |
+
if lora.get("trigger_position") == "prepend":
|
| 152 |
+
prepend_words.append(tw)
|
| 153 |
+
else:
|
| 154 |
+
append_words.append(tw)
|
| 155 |
+
prompt_mash = " ".join(prepend_words + [prompt] + append_words)
|
| 156 |
+
|
| 157 |
+
if randomize_seed or seed == 0:
|
| 158 |
+
seed = random.randint(0, MAX_SEED)
|
| 159 |
+
|
| 160 |
+
logger.info(f"Generating with prompt: {prompt_mash} Seed: {seed}")
|
| 161 |
+
|
| 162 |
+
try:
|
| 163 |
+
lora_manager.set_active_loras(pipe, selected_loras, [lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4])
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logger.error(f"LoRA weight loading failed: {e}")
|
| 166 |
+
raise gr.Error(f"Failed to load LoRA weights: {str(e)}")
|
| 167 |
+
|
| 168 |
+
return generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
|
| 169 |
+
|
| 170 |
+
# === UI Setup ===
|
| 171 |
+
MOBILE_CSS = '''
|
| 172 |
+
@media (max-width: 600px) {
|
| 173 |
+
.gr-row { flex-direction: column !important; }
|
| 174 |
+
.button_total { width: 100% !important; }
|
| 175 |
+
}
|
| 176 |
+
'''
|
| 177 |
+
|
| 178 |
+
font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
|
| 179 |
+
|
| 180 |
+
with gr.Blocks(theme=gr.themes.Soft(font=font), css=MOBILE_CSS, delete_cache=(128, 256)) as app:
|
| 181 |
+
# Title and app state
|
| 182 |
+
gr.HTML(
|
| 183 |
+
'<h1><img src="https://huggingface.co/spaces/keltezaa/Celebrity_LoRa_Mix/resolve/main/solo-traveller_16875043.png" alt="LoRA"> Celebrity_LoRa_Mix</h1>',
|
| 184 |
+
elem_id="title"
|
| 185 |
+
)
|
| 186 |
+
loras_state = gr.State(LORA_LIST)
|
| 187 |
+
selected_lora_indices = gr.State([])
|
| 188 |
+
|
| 189 |
+
# Main input prompt box
|
| 190 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Type a prompt after selecting a LoRA")
|
| 191 |
+
|
| 192 |
+
# LoRA selectors, sliders and images - built modularly here...
|
| 193 |
+
|
| 194 |
+
# Advanced parameters
|
| 195 |
+
with gr.Accordion("Advanced Settings", open=True):
|
| 196 |
+
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
|
| 197 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
|
| 198 |
+
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
|
| 199 |
+
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
|
| 200 |
+
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
| 201 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
| 202 |
+
|
| 203 |
+
generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
|
| 204 |
+
output_img = gr.Image(interactive=False, show_share_button=False)
|
| 205 |
+
progress_bar = gr.Markdown(visible=False)
|
| 206 |
+
|
| 207 |
+
# Bind callbacks here (your existing logic, updated variable names)
|
| 208 |
+
|
| 209 |
+
app.queue(concurrency_count=3).launch()
|