File size: 8,429 Bytes
c03ae6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
app.py - Mega-Scale Refactor for Celebrity_LoRa_Mix Space

Features:
- Modular imports and dependency management
- Advanced error handling with user-facing messages
- Async-ready pipeline integration with fallback sync support
- Mobile-first responsive layout with concise UX messaging
- Leverages helpers.py and lora_manager.py for clarity and reuse

Author: Helios Automation Alchemist
"""

# Standard library
import asyncio
import json
import logging
import os
import random
import sys
import time
from typing import List

# Third-party
import gradio as gr
import pandas as pd
import requests
import spaces  # required by the @spaces.GPU decorator on run_lora (HF ZeroGPU)
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
from PIL import Image
from transformers import CLIPTokenizer, CLIPProcessor, CLIPModel, LongformerTokenizer, LongformerModel

# Custom modules
import helpers
from lora_manager import LoRAManager

# === Config ===
# Disable fork-time parallelism warning from the HF tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')
logger = logging.getLogger(__name__)

# Largest value accepted for a 32-bit RNG seed (used by the seed slider and randomizer).
MAX_SEED = 2**32 - 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 on GPU for speed/memory; float32 on CPU where bf16 support varies.
DTYPE = torch.bfloat16 if DEVICE.type == 'cuda' else torch.float32

# === Model & tokenizer loading ===
def load_tokenizers_and_models():
    """Load the CLIP and Longformer tokenizers/models used by the app.

    Returns:
        Tuple of (clip_tokenizer, clip_processor, clip_model,
        longformer_tokenizer, longformer_model).

    Exits the process with status 1 if any download/load fails, since the
    app cannot function without these encoders.
    """
    try:
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
        clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        logger.info("CLIP tokenizer & model loaded.")

        longformer_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
        longformer_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
        logger.info("Longformer tokenizer & model loaded.")
        return clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model
    except Exception:
        # logger.exception preserves the full traceback (logger.error with an
        # f-string dropped it), which matters for diagnosing hub/network failures.
        logger.exception("Tokenizer/model load failed")
        sys.exit(1)

# Eagerly load all encoders at import time; the process exits on failure.
clip_tokenizer, clip_processor, clip_model, longformer_tokenizer, longformer_model = load_tokenizers_and_models()

# === Load prompts and LoRAs ===
def load_prompts_and_loras():
    """Load prompt suggestions and the LoRA catalog from local files.

    Returns:
        (prompts, loras): a flat array of prompt strings (numpy array from
        pandas, or ``[]`` on failure) and a list of LoRA descriptor dicts
        (or ``[]`` on failure).

    Both inputs are optional; any read/parse failure degrades to an empty
    default with a warning rather than crashing the app at import time.
    """
    try:
        prompts = pd.read_csv("prompts.csv", header=None).values.flatten()
    except FileNotFoundError:
        logger.warning("prompts.csv missing, defaulting to empty prompts.")
        prompts = []
    except pd.errors.EmptyDataError:
        # A present-but-empty CSV previously crashed the app at import.
        logger.warning("prompts.csv is empty, defaulting to empty prompts.")
        prompts = []
    try:
        with open("loras.json", "r") as f:
            loras = json.load(f)
    except FileNotFoundError:
        logger.warning("loras.json missing, defaulting to empty LoRA list.")
        loras = []
    except json.JSONDecodeError as e:
        # Malformed JSON degrades gracefully, consistent with the missing-file path.
        logger.warning("loras.json is malformed (%s), defaulting to empty LoRA list.", e)
        loras = []
    return prompts, loras

# Module-level data used to populate the UI and the per-session gr.State.
PROMPT_VALUES, LORA_LIST = load_prompts_and_loras()

# === Initialize Diffusion Pipeline with retry and fallback ===
def initialize_pipeline(base_model="sayakpaul/FLUX.1-merged", max_retries=3):
    """Build the text-to-image and image-to-image diffusion pipelines.

    Args:
        base_model: Hub id of the base FLUX checkpoint.
        max_retries: Number of load attempts before giving up (transient
            hub/network failures are the expected cause), waiting 5s between
            attempts.

    Returns:
        (pipe, pipe_i2i): the txt2img pipeline (with the fast TAEF1 VAE for
        streaming previews) and an img2img pipeline that shares its
        transformer, text encoders and tokenizers but uses the full-quality
        VAE.

    Exits the process with status 1 if all attempts fail.
    """
    for attempt in range(max_retries):
        try:
            # Tiny autoencoder: fast, lower-quality decodes for live previews.
            taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to(DEVICE)
            # Full-quality VAE from the base checkpoint for final img2img output.
            good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)

            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=DTYPE, vae=taef1).to(DEVICE)
            # Share the heavy submodules with the img2img pipeline to avoid
            # loading the transformer/text encoders twice.
            pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
                base_model,
                vae=good_vae,
                transformer=pipe.transformer,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                text_encoder_2=pipe.text_encoder_2,
                tokenizer_2=pipe.tokenizer_2,
                torch_dtype=DTYPE
            )
            # Bind the streaming helper as a bound method so it can access
            # pipeline internals through `self`.
            pipe.flux_pipe_call_that_returns_an_iterable_of_images = helpers.flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)

            logger.info("Diffusion pipeline loaded successfully.")
            return pipe, pipe_i2i
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                time.sleep(5)  # back off between retries; no sleep after the last attempt
    logger.error("Failed to load diffusion pipeline after retries.")
    sys.exit(1)

# Build both pipelines once at import; the process exits on repeated failure.
pipe, pipe_i2i = initialize_pipeline()

# === LoRA Manager for adapter lifecycle ===
lora_manager = LoRAManager(LORA_LIST)

# === Core business logic ===
def process_input(text: str, max_length: int = 4096):
    """Validate and tokenize a user prompt with the Longformer tokenizer.

    Args:
        text: Raw prompt text from the UI.
        max_length: Truncation limit in tokens (4096 matches the
            longformer-base-4096 context window).

    Returns:
        The tokenizer's output (PyTorch tensors, padded and truncated).

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    is_blank = (not text) or (text.strip() == "")
    if is_blank:
        raise gr.Error("Prompt cannot be empty.")
    return longformer_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

@helpers.async_run_if_possible
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
    """Stream intermediate images from the diffusion pipeline.

    Yields one ``(image, seed, progress_update)`` triple per inference step so
    the UI can show live previews and a step counter.
    """
    pipe.to(DEVICE)  # ensure the pipeline is on the target device (no-op if already there)
    # Seeded generator makes sampling reproducible for a given seed.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    with helpers.calculate_duration("Generating image"):
        for step_idx, img in enumerate(pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                prompt=prompt,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": 1.0},
                output_type="pil",
                # NOTE(review): pipe.vae is the tiny TAEF1 preview VAE here;
                # the full-quality AutoencoderKL lives on pipe_i2i — confirm
                # which VAE the final decode was meant to use.
                good_vae=pipe.vae,
            )):
            yield img, seed, gr.update(value=f"Step {step_idx + 1}/{steps}", visible=True)

# ZeroGPU: request a GPU for at most 75 seconds per invocation.
@spaces.GPU(duration=75)
def run_lora(prompt, cfg_scale, steps, selected_loras_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4,
             randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
    """Compose the prompt with the selected LoRAs' trigger words, activate
    those adapters on the pipeline, and delegate to generate_image.

    Raises:
        gr.Error: if no LoRA is selected or adapter weight loading fails.
    """
    if not selected_loras_indices:
        raise gr.Error("Select at least one LoRA.")

    # Resolve UI selection indices against this session's LoRA list.
    selected_loras = [loras_state[i] for i in selected_loras_indices]

    # Compose prompt with LoRA trigger words
    prepend_words = []
    append_words = []
    for lora in selected_loras:
        tw = lora.get("trigger_word", "")
        if tw:
            # Triggers default to appending unless the catalog says "prepend".
            if lora.get("trigger_position") == "prepend":
                prepend_words.append(tw)
            else:
                append_words.append(tw)
    prompt_mash = " ".join(prepend_words + [prompt] + append_words)

    # seed == 0 is treated as "unset" and randomized even without the checkbox.
    if randomize_seed or seed == 0:
        seed = random.randint(0, MAX_SEED)

    logger.info(f"Generating with prompt: {prompt_mash} Seed: {seed}")

    try:
        # Always passes the four slider values; presumably LoRAManager ignores
        # scales beyond len(selected_loras) — TODO confirm against lora_manager.py.
        lora_manager.set_active_loras(pipe, selected_loras, [lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4])
    except Exception as e:
        logger.error(f"LoRA weight loading failed: {e}")
        raise gr.Error(f"Failed to load LoRA weights: {str(e)}")

    # generate_image is a generator; returning it lets Gradio stream previews.
    return generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)

# === UI Setup ===
# Mobile-first tweaks: stack rows vertically and stretch the main button
# on narrow (<=600px) viewports.
MOBILE_CSS = '''
@media (max-width: 600px) {
  .gr-row { flex-direction: column !important; }
  .button_total { width: 100% !important; }
}
'''

# Font stack: Google-hosted primary with generic fallbacks.
font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]

with gr.Blocks(theme=gr.themes.Soft(font=font), css=MOBILE_CSS, delete_cache=(128, 256)) as app:
    # Title and app state
    gr.HTML(
        '<h1><img src="https://huggingface.co/spaces/keltezaa/Celebrity_LoRa_Mix/resolve/main/solo-traveller_16875043.png" alt="LoRA"> Celebrity_LoRa_Mix</h1>',
        elem_id="title"
    )
    # Per-session state: the LoRA catalog and the user's current selection.
    loras_state = gr.State(LORA_LIST)
    selected_lora_indices = gr.State([])

    # Main input prompt box
    prompt = gr.Textbox(label="Prompt", placeholder="Type a prompt after selecting a LoRA")

    # LoRA selectors, sliders and images - built modularly here...

    # Advanced parameters
    with gr.Accordion("Advanced Settings", open=True):
        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        # seed value 0 is treated as "unset" by run_lora and gets randomized.
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)

    generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
    output_img = gr.Image(interactive=False, show_share_button=False)
    # Hidden until generate_image's yielded gr.update makes it visible.
    progress_bar = gr.Markdown(visible=False)

    # Bind callbacks here (your existing logic, updated variable names)

# NOTE(review): queue(concurrency_count=...) is Gradio 3.x API; Gradio 4+
# renamed it to default_concurrency_limit — confirm the installed version.
app.queue(concurrency_count=3).launch()