# ===== ZeroGPU timeout optimization - final version =====

try:
    import spaces
    SPACES_AVAILABLE = True
    print("✅ ZeroGPU mode enabled")
except ImportError:
    SPACES_AVAILABLE = False
    print("⚠️ Running in regular mode")

import os
from datetime import datetime
import random
import torch
import gradio as gr
from diffusers import AutoPipelineForText2Image, FlowMatchEulerDiscreteScheduler
from PIL import Image
import traceback
import numpy as np
import gc
import warnings
warnings.filterwarnings('ignore')

# ===== Configuration =====
FIXED_MODEL = "aoxo/flux.1dev-abliterated"
SAVE_DIR = "generated_images"
os.makedirs(SAVE_DIR, exist_ok=True)

STYLE_PRESETS = {
    "None": "",
    "Realistic": "photorealistic, detailed",
    "Anime": "anime style, high quality",
    "Comic": "comic book style",
    "Watercolor": "watercolor painting"
}

# ===== Global state =====
pipeline = None
device = None
model_loaded = False


def cleanup_memory():
    """激进的内存清理"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def apply_spaces_decorator(func):
    """ZeroGPU 装饰器 - 60秒限制"""
    if SPACES_AVAILABLE:
        # ZeroGPU actually grants only 60 seconds of GPU time per call!
        return spaces.GPU(duration=60)(func)
    return func


def enhance_prompt_minimal(prompt: str, style: str) -> str:
    """最小化提示词增强 - 严格控制长度"""
    style_suffix = STYLE_PRESETS.get(style, "")
    
    if style_suffix:
        enhanced = f"{prompt}, {style_suffix}, masterpiece"
    else:
        enhanced = f"{prompt}, masterpiece"
    
    # CLIP hard limit: 77 tokens ≈ 200-250 characters
    if len(enhanced) > 200:
        enhanced = prompt[:180] + ", masterpiece"
        print(f"⚠️ Prompt truncated to fit CLIP limit")
    
    return enhanced


# ===== Separate model initialization (no GPU decorator) =====
def initialize_model():
    """模型初始化 - 不占用 GPU 时间"""
    global pipeline, device, model_loaded
    
    if model_loaded and pipeline is not None:
        return True
    
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Device: {device}")
        
        print(f"📦 Loading: {FIXED_MODEL}")
        
        pipeline = AutoPipelineForText2Image.from_pretrained(
            FIXED_MODEL,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            use_safetensors=True,
        )
        
        pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            pipeline.scheduler.config
        )
        
        # Key optimization: skip offloading and load everything onto the device
        pipeline = pipeline.to(device)
        
        # Keep only the most essential optimizations
        if torch.cuda.is_available():
            pipeline.vae.enable_slicing()  # slice VAE decode to lower peak memory
            pipeline.vae.enable_tiling()   # tile VAE decode for larger resolutions
        
        print("✅ Model ready")
        model_loaded = True
        return True
        
    except Exception as e:
        print(f"❌ Init failed: {e}")
        return False


@apply_spaces_decorator
def generate_image_fast(prompt: str, style: str, negative_prompt: str, 
                       steps: int, cfg_scale: float, seed: int, 
                       width: int, height: int):
    """超快速生成 - 必须在 60 秒内完成"""
    try:
        print(f"⏱️ GPU timer started (60s limit)")
        
        if seed == -1:
            seed = random.randint(0, 999999)
        
        enhanced_prompt = enhance_prompt_minimal(prompt, style)
        
        if not negative_prompt:
            negative_prompt = "low quality, blurry"
        
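        # A CPU generator keeps the seed reproducible regardless of which GPU is assigned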
        generator = torch.Generator("cpu").manual_seed(seed)
        
        print(f"🚀 Generating: {steps} steps, {width}x{height}")
        
        cleanup_memory()
        
        # Minimal inference parameters
        with torch.inference_mode():  # faster than no_grad
            result = pipeline(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                output_type="pil"
            )
        
        image = result.images[0]
        del result
        cleanup_memory()
        
        print(f"✅ Done in <60s")
        return image, seed
        
    except Exception as e:
        cleanup_memory()
        print(f"❌ Error: {e}")
        raise e


def generate_wrapper(prompt, style, neg_prompt, steps, cfg, seed, size_preset, progress=gr.Progress()):
    """包装函数 - 处理 UI 逻辑"""
    try:
        if not prompt.strip():
            return None, "❌ Enter a prompt", "", None
        
        # Parse the size preset
        if size_preset == "512x512 (Ultra Fast)":
            width = height = 512
        elif size_preset == "768x768 (Fast)":
            width = height = 768
        else:
            width = height = 1024
        
        # Clamp steps to the supported range (8-15)
        steps = max(8, min(steps, 15))
        
        progress(0.1, desc="Initializing...")
        
        # Pre-load the model (does not count against GPU time)
        if not initialize_model():
            return None, "❌ Model init failed", "", None
        
        progress(0.2, desc="Generating (30-50s)...")
        
        # Call the GPU-decorated function
        image, actual_seed = generate_image_fast(
            prompt, style, neg_prompt, steps, cfg, seed, width, height
        )
        
        progress(0.9, desc="Saving...")
        
        filename = f"IMG_{actual_seed}.png"
        filepath = os.path.join(SAVE_DIR, filename)
        image.save(filepath)
        
        metadata = f"""Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Prompt: {prompt}
Style: {style}
Seed: {actual_seed}
Steps: {steps} | CFG: {cfg}
Size: {width}x{height}
"""
        
        info = f"Seed: {actual_seed} | {width}×{height} | {steps} steps"
        
        progress(1.0, desc="Complete!")
        
        return image, info, metadata, image
        
    except Exception as e:
        cleanup_memory()
        error_msg = f"Generation failed: {str(e)[:100]}"
        print(f"❌ {error_msg}")
        return None, error_msg, "", None


# ===== UI =====
def create_interface():
    with gr.Blocks(title="Fast FLUX Generator") as interface:
        gr.HTML('<h1 style="text-align:center">⚡ Fast FLUX Generator</h1>')
        
        gr.HTML('''
            <div style="background:#fff3cd;padding:10px;border-radius:8px;margin:10px 0;">
                <strong>⚠️ ZeroGPU Limits:</strong><br>
                • 60 second GPU timeout (hard limit)<br>
                • Recommended: 512x512 or 768x768, 10-15 steps<br>
                • Keep prompts under 200 characters
            </div>
        ''')
        
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt (keep it short!)",
                    placeholder="woman, portrait, detailed",
                    lines=4,
                    max_lines=4
                )
                
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality, blurry",
                    lines=2
                )
            
            with gr.Column(scale=1):
                style_input = gr.Radio(
                    label="Style",
                    choices=list(STYLE_PRESETS.keys()),
                    value="Realistic"
                )
                
                seed_input = gr.Number(
                    label="Seed (-1 = random)",
                    value=-1,
                    precision=0
                )
                
                size_preset = gr.Radio(
                    label="Size (smaller = faster)",
                    choices=[
                        "512x512 (Ultra Fast)",
                        "768x768 (Fast)",
                        "1024x1024 (Slow)"
                    ],
                    value="768x768 (Fast)"
                )
                
                steps_input = gr.Slider(
                    label="Steps (10-15 recommended)",
                    minimum=8,
                    maximum=15,
                    value=12,
                    step=1
                )
                
                cfg_input = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5
                )
                
                generate_button = gr.Button(
                    "🚀 GENERATE (30-50s)",
                    variant="primary",
                    size="lg"
                )
        
        image_output = gr.Image(label="Result", show_label=False)
        
        generation_info = gr.Textbox(
            label="Info",
            interactive=False,
            visible=True
        )
        
        metadata_content = gr.Textbox(visible=False)
        current_image = gr.Image(visible=False)
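        # Hidden components absorb the extra metadata/image outputs returned by generate_wrapper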
        
        generate_button.click(
            fn=generate_wrapper,
            inputs=[
                prompt_input, style_input, negative_prompt_input,
                steps_input, cfg_input, seed_input, size_preset
            ],
            outputs=[
                image_output, generation_info, 
                metadata_content, current_image
            ],
            show_progress=True
        )
        
        prompt_input.submit(
            fn=generate_wrapper,
            inputs=[
                prompt_input, style_input, negative_prompt_input,
                steps_input, cfg_input, seed_input, size_preset
            ],
            outputs=[
                image_output, generation_info, 
                metadata_content, current_image
            ],
            show_progress=True
        )
    
    return interface


if __name__ == "__main__":
    print("🚀 Starting Fast FLUX Generator")
    print(f"🔧 Model: {FIXED_MODEL}")
    print(f"🔧 CUDA: {torch.cuda.is_available()}")
    
    # Pre-load the model
    print("📦 Pre-loading model...")
    initialize_model()
    
    app = create_interface()
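    # A short queue with concurrency 1 serializes requests so each generation gets its own GPU window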
    app.queue(max_size=3, default_concurrency_limit=1)
    
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )