Spaces:
Runtime error
Runtime error
| """ | |
| app.py - Mega-Scale Refactor for Celebrity_LoRa_Mix Space | |
| Features: | |
| - Modular imports and dependency management | |
| - Advanced error handling with user-facing messages | |
| - Async-ready pipeline integration with fallback sync support | |
| - Mobile-first responsive layout with concise UX messaging | |
| - Leverages helpers.py and lora_manager.py for clarity and reuse | |
| Author: Helios Automation Alchemist | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import logging | |
| import random | |
| import time | |
| import asyncio | |
| from typing import List | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image | |
| from transformers import CLIPTokenizer, CLIPProcessor, CLIPModel, LongformerTokenizer, LongformerModel | |
| from PIL import Image | |
| # Custom modules | |
| import helpers | |
| from lora_manager import LoRAManager | |
# === Config ===
# Turn off HuggingFace `tokenizers` parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# Upper bound for seeds (inclusive): the full unsigned 32-bit range.
MAX_SEED = (1 << 32) - 1

# Run on the GPU when one is visible; use bfloat16 there and float32 on CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DTYPE = torch.float32 if DEVICE.type != "cuda" else torch.bfloat16
# === Model & tokenizer loading ===
def load_tokenizers_and_models(
    clip_name="openai/clip-vit-base-patch16",
    longformer_name="allenai/longformer-base-4096",
):
    """Load the CLIP and Longformer tokenizers/models used by the app.

    Args:
        clip_name: Hub id of the CLIP checkpoint; tokenizer, processor and
            model are all loaded from it.
        longformer_name: Hub id of the Longformer checkpoint; tokenizer and
            model are loaded from it.

    Returns:
        Tuple ``(clip_tokenizer, clip_processor, clip_model,
        longformer_tokenizer, longformer_model)``.

    On any load failure the full traceback is logged and the process exits
    with status 1 (via ``sys.exit``).
    """
    # NOTE(review): in the code visible here only `longformer_tokenizer` is
    # referenced (by process_input); the CLIP objects and `longformer_model`
    # cost memory at startup — confirm they are used elsewhere before keeping.
    try:
        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        clip_processor = CLIPProcessor.from_pretrained(clip_name)
        clip_model = CLIPModel.from_pretrained(clip_name)
        logger.info("CLIP tokenizer & model loaded.")

        longformer_tokenizer = LongformerTokenizer.from_pretrained(longformer_name)
        longformer_model = LongformerModel.from_pretrained(longformer_name)
        logger.info("Longformer tokenizer & model loaded.")
    except Exception as e:
        # logger.exception keeps the traceback; logger.error would log only the
        # message, which makes a startup crash on a Space hard to diagnose.
        logger.exception("Tokenizer/model load failed: %s", e)
        sys.exit(1)
    return clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model


clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model = load_tokenizers_and_models()
# === Load prompts and LoRAs ===
def load_prompts_and_loras(prompts_path="prompts.csv", loras_path="loras.json"):
    """Load example prompts and LoRA metadata from disk.

    Args:
        prompts_path: Header-less CSV file; every non-empty cell, read row by
            row, becomes one prompt.
        loras_path: UTF-8 JSON file holding the LoRA definitions (presumably
            a list of dicts — the content is not validated here).

    Returns:
        ``(prompts, loras)``. ``prompts`` is always a ``list``. Each value falls
        back to ``[]`` with a logged warning/error when its file is missing,
        empty, or cannot be parsed.
    """
    try:
        frame = pd.read_csv(prompts_path, header=None)
    except FileNotFoundError:
        logger.warning(f"{prompts_path} missing, defaulting to empty prompts.")
        prompts = []
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        logger.warning(f"{prompts_path} unreadable ({e}), defaulting to empty prompts.")
        prompts = []
    else:
        # pandas pads short rows with NaN; drop those so they don't become
        # "nan" prompts. tolist() keeps the type consistent with the [] fallback.
        prompts = [cell for cell in frame.values.flatten().tolist() if pd.notna(cell)]

    try:
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(loras_path, "r", encoding="utf-8") as f:
            loras = json.load(f)
    except FileNotFoundError:
        logger.warning(f"{loras_path} missing, defaulting to empty LoRA list.")
        loras = []
    except json.JSONDecodeError as e:
        logger.error(f"{loras_path} is not valid JSON ({e}), defaulting to empty LoRA list.")
        loras = []
    return prompts, loras
# Loaded once at import time: PROMPT_VALUES holds the cell values read from
# prompts.csv and LORA_LIST the parsed contents of loras.json; each is empty
# when its file is missing.
PROMPT_VALUES, LORA_LIST = load_prompts_and_loras()
# === Initialize Diffusion Pipeline (with retries) ===
def initialize_pipeline(base_model="sayakpaul/FLUX.1-merged", max_retries=3, retry_delay=5.0):
    """Load the text-to-image and image-to-image pipelines for *base_model*.

    ``pipe`` is a ``DiffusionPipeline`` whose VAE is the ``madebyollin/taef1``
    AutoencoderTiny. ``pipe_i2i`` is an ``AutoPipelineForImage2Image`` that
    reuses ``pipe``'s transformer, text encoders and tokenizers, but uses the
    full ``AutoencoderKL`` from ``base_model``'s ``vae`` subfolder. The helper
    ``helpers.flux_pipe_call_that_returns_an_iterable_of_images`` is bound to
    ``pipe`` as a method.

    Args:
        base_model: Hub id of the FLUX checkpoint.
        max_retries: Number of load attempts; values below 1 still make one.
        retry_delay: Seconds to wait between attempts (not after the last one).

    Returns:
        ``(pipe, pipe_i2i)``.

    Exits the process with status 1 if every attempt fails.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to(DEVICE)
            good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=DTYPE, vae=taef1).to(DEVICE)
            # Share the heavy components with `pipe` instead of loading them twice.
            pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
                base_model,
                vae=good_vae,
                transformer=pipe.transformer,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                text_encoder_2=pipe.text_encoder_2,
                tokenizer_2=pipe.tokenizer_2,
                torch_dtype=DTYPE,
            )
            # __get__ binds the plain helper function to this pipeline instance.
            pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
                helpers.flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
            )
        except Exception as e:
            # Full traceback only once, on the final failure, to keep logs short.
            logger.warning("Attempt %d failed: %s", attempt, e, exc_info=attempt == attempts)
            if attempt < attempts:
                time.sleep(retry_delay)
        else:
            logger.info("Diffusion pipeline loaded successfully.")
            return pipe, pipe_i2i
    logger.error("Failed to load diffusion pipeline after retries.")
    sys.exit(1)


pipe, pipe_i2i = initialize_pipeline()
# === LoRA Manager for adapter lifecycle ===
# Built from the LoRA list loaded at startup; run_lora() calls
# lora_manager.set_active_loras(pipe, ...) to apply the user's selection
# before each generation.
lora_manager = LoRAManager(LORA_LIST)
# === Core business logic ===
def process_input(text: str, max_length: int = 4096):
    """Tokenize *text* with the module-level Longformer tokenizer.

    Args:
        text: Prompt text to encode.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        The tokenizer output with PyTorch tensors (``return_tensors="pt"``).

    Raises:
        gr.Error: If *text* is None, empty, or only whitespace.
    """
    if not (text and text.strip()):
        raise gr.Error("Prompt cannot be empty.")
    encoded = longformer_tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    return encoded
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
    """Stream images for *prompt* from the FLUX text-to-image pipeline.

    Yields ``(image, seed, status_update)`` for every item produced by
    ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``. The third
    element is a ``gr.update`` that shows a "Step i/N" label.

    Args:
        prompt: Final prompt text (trigger words already merged in).
        steps: Number of inference steps.
        seed: Seed for the torch generator; echoed back in every yield.
        cfg_scale: Guidance scale.
        width, height: Output size in pixels.
        progress: Accepted for signature compatibility; unused here.
    """
    # Gradio sliders may deliver floats; torch.Generator.manual_seed and the
    # step count need ints.
    steps = int(steps)
    seed = int(seed)
    width = int(width)
    height = int(height)

    pipe.to(DEVICE)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    frames = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": 1.0},
        output_type="pil",
        # pipe.vae is the TAEF1 tiny autoencoder that initialize_pipeline built
        # `pipe` with; the full AutoencoderKL it loaded lives on pipe_i2i.vae.
        # Presumably the helper uses `good_vae` for the final decode — confirm
        # in helpers.py.
        good_vae=pipe_i2i.vae,
    )
    with helpers.calculate_duration("Generating image"):
        for step_idx, img in enumerate(frames):
            yield img, seed, gr.update(value=f"Step {step_idx + 1}/{steps}", visible=True)
def run_lora(prompt, cfg_scale, steps, selected_loras_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4,
             randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
    """Apply the selected LoRAs and stream generated images.

    This is a generator function, so Gradio streams each yielded
    ``(image, seed, status_update)`` to the bound outputs. Returning the inner
    generator from a plain function, as before, would make Gradio treat the
    generator object itself as the output value.

    Args:
        prompt: User prompt.
        cfg_scale, steps, width, height: Forwarded to generate_image.
        selected_loras_indices: Indices into *loras_state*; must be non-empty.
        lora_scale_1..lora_scale_4: Scale values passed positionally to
            ``lora_manager.set_active_loras`` (how they map to LoRAs is defined
            there).
        randomize_seed: When true, or when *seed* is 0, a random seed in
            ``[0, MAX_SEED]`` is drawn.
        seed: Seed to use otherwise.
        loras_state: List of LoRA dicts; ``trigger_word`` and
            ``trigger_position`` keys are read here.
        progress: Gradio progress tracker, forwarded to generate_image.

    Raises:
        gr.Error: If no LoRA is selected or LoRA weights fail to load.
    """
    if not selected_loras_indices:
        raise gr.Error("Select at least one LoRA.")
    selected_loras = [loras_state[i] for i in selected_loras_indices]

    # Trigger words go before the prompt when trigger_position == "prepend",
    # otherwise after it.
    prepend_words, append_words = [], []
    for lora in selected_loras:
        trigger = lora.get("trigger_word", "")
        if not trigger:
            continue
        bucket = prepend_words if lora.get("trigger_position") == "prepend" else append_words
        bucket.append(trigger)
    prompt_mash = " ".join(prepend_words + [prompt] + append_words)

    if randomize_seed or seed == 0:
        seed = random.randint(0, MAX_SEED)
    logger.info("Generating with prompt: %s Seed: %s", prompt_mash, seed)

    try:
        lora_manager.set_active_loras(pipe, selected_loras, [lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4])
    except Exception as e:
        logger.exception("LoRA weight loading failed: %s", e)
        raise gr.Error(f"Failed to load LoRA weights: {str(e)}") from e

    yield from generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
# === UI Setup ===
MOBILE_CSS = '''
@media (max-width: 600px) {
    .gr-row { flex-direction: column !important; }
    .button_total { width: 100% !important; }
}
'''
font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]

with gr.Blocks(theme=gr.themes.Soft(font=font), css=MOBILE_CSS, delete_cache=(128, 256)) as app:
    # Title and app state
    gr.HTML(
        '<h1><img src="https://huggingface.co/spaces/keltezaa/Celebrity_LoRa_Mix/resolve/main/solo-traveller_16875043.png" alt="LoRA"> Celebrity_LoRa_Mix</h1>',
        elem_id="title",
    )
    loras_state = gr.State(LORA_LIST)
    selected_lora_indices = gr.State([])

    # Main input prompt box
    prompt = gr.Textbox(label="Prompt", placeholder="Type a prompt after selecting a LoRA")

    # LoRA selectors, sliders and images - built modularly here...

    # Advanced parameters
    with gr.Accordion("Advanced Settings", open=True):
        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)

    generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
    output_img = gr.Image(interactive=False, show_share_button=False)
    progress_bar = gr.Markdown(visible=False)

    # TODO(review): no event handler is attached to generate_button here, so
    # clicking it does nothing. run_lora also expects four LoRA-scale inputs
    # that this block does not define; once they exist, bind with e.g.
    # generate_button.click(run_lora, inputs=[...], outputs=[output_img, seed, progress_bar]).

if __name__ == "__main__":
    # `concurrency_count` was removed from Blocks.queue() in Gradio 4 (this file
    # already relies on Gradio 4 via `delete_cache`), so passing it raises
    # TypeError at startup. `default_concurrency_limit` is the replacement.
    app.queue(default_concurrency_limit=3).launch()