# ===== ZeroGPU timeout optimization, final version =====
try:
    import spaces
    SPACES_AVAILABLE = True
    print("✅ ZeroGPU mode enabled")
except ImportError:
    SPACES_AVAILABLE = False
    print("⚠️ Running in regular mode")

import gc
import os
import random
import warnings
from datetime import datetime

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, FlowMatchEulerDiscreteScheduler

warnings.filterwarnings("ignore")

# ===== Configuration =====
FIXED_MODEL = "aoxo/flux.1dev-abliterated"
SAVE_DIR = "generated_images"
os.makedirs(SAVE_DIR, exist_ok=True)

STYLE_PRESETS = {
    "None": "",
    "Realistic": "photorealistic, detailed",
    "Anime": "anime style, high quality",
    "Comic": "comic book style",
    "Watercolor": "watercolor painting",
}

# ===== Global state =====
pipeline = None
device = None
model_loaded = False


def cleanup_memory():
    """Aggressive memory cleanup between generations."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def apply_spaces_decorator(func):
    """Wrap func with the ZeroGPU decorator (60-second budget)."""
    if SPACES_AVAILABLE:
        # ZeroGPU only grants 60 seconds of GPU time per call!
        return spaces.GPU(duration=60)(func)
    return func
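
# Note: per the ZeroGPU docs, `duration` may also be a callable that receives
# the same arguments as the wrapped function, so the reservation can scale
# with the request instead of a flat 60 s. A minimal sketch (the budget
# formula below is an assumption, not a measured figure):
#
#   def gpu_duration(prompt, style, negative_prompt, steps, *args):
#       return 20 + 3 * steps  # rough seconds-per-step budget
#
#   @spaces.GPU(duration=gpu_duration)
#   def generate(...): ...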

def enhance_prompt_minimal(prompt: str, style: str) -> str:
    """Minimal prompt enhancement with strict length control."""
    style_suffix = STYLE_PRESETS.get(style, "")
    if style_suffix:
        enhanced = f"{prompt}, {style_suffix}, masterpiece"
    else:
        enhanced = f"{prompt}, masterpiece"
    # CLIP hard limit: 77 tokens ≈ 200-250 characters.
    if len(enhanced) > 200:
        enhanced = prompt[:180] + ", masterpiece"
        print("⚠️ Prompt truncated to fit CLIP limit")
    return enhanced
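
# The 200-character cutoff above is only a proxy for CLIP's 77-token window.
# Once the pipeline is loaded, its bundled CLIP tokenizer can count tokens
# exactly; a minimal sketch (assumes `pipeline.tokenizer` is the CLIP
# tokenizer, as in the diffusers FLUX pipelines):
#
#   def clip_token_count(text: str) -> int:
#       return len(pipeline.tokenizer(text).input_ids)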

# ===== Model initialization, kept outside the GPU decorator =====
def initialize_model():
    """Load the model without consuming decorated GPU time."""
    global pipeline, device, model_loaded
    if model_loaded and pipeline is not None:
        return True
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Device: {device}")
        print(f"📦 Loading: {FIXED_MODEL}")
        pipeline = AutoPipelineForText2Image.from_pretrained(
            FIXED_MODEL,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            use_safetensors=True,
        )
        pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            pipeline.scheduler.config
        )
        # Key optimization: no offload, load everything onto the device.
        pipeline = pipeline.to(device)
        # Keep only the most essential memory optimizations.
        if torch.cuda.is_available():
            pipeline.enable_vae_slicing()
            pipeline.enable_vae_tiling()
        print("✅ Model ready")
        model_loaded = True
        return True
    except Exception as e:
        print(f"❌ Init failed: {e}")
        return False
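
# If 1024x1024 runs out of VRAM, CPU offload is the usual diffusers fallback,
# but it adds per-step transfer time that is risky under a 60 s budget, which
# is why the pipeline is kept fully resident above:
#
#   # pipeline.enable_model_cpu_offload()  # lower peak VRAM, slower steps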

@apply_spaces_decorator
def generate_image_fast(prompt: str, style: str, negative_prompt: str,
                        steps: int, cfg_scale: float, seed: int,
                        width: int, height: int):
    """Ultra-fast generation; must complete within the 60-second GPU window."""
    try:
        print("⏱️ GPU timer started (60s limit)")
        if seed == -1:
            seed = random.randint(0, 999999)
        enhanced_prompt = enhance_prompt_minimal(prompt, style)
        if not negative_prompt:
            negative_prompt = "low quality, blurry"
        # A CPU generator keeps seeds reproducible across devices.
        generator = torch.Generator("cpu").manual_seed(seed)
        print(f"🚀 Generating: {steps} steps, {width}x{height}")
        cleanup_memory()
        # Minimal inference parameters.
        with torch.inference_mode():  # faster than no_grad
            result = pipeline(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                output_type="pil",
            )
        image = result.images[0]
        del result
        cleanup_memory()
        print("✅ Done in <60s")
        return image, seed
    except Exception as e:
        cleanup_memory()
        print(f"❌ Error: {e}")
        raise
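
# Caveat: FLUX.1-dev is guidance-distilled, and in the diffusers Flux
# pipelines `negative_prompt` only takes effect when true CFG is enabled
# (roughly, `true_cfg_scale > 1`, at about twice the cost per step), so the
# negative prompt above may be a no-op for this model. This is a note on
# library behavior, not something the original code asserts.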

def generate_wrapper(prompt, style, neg_prompt, steps, cfg, seed, size_preset,
                     progress=gr.Progress()):
    """Wrapper that handles the UI logic around the GPU call."""
    try:
        if not prompt.strip():
            return None, "❌ Enter a prompt", "", None
        # Parse the size preset.
        if size_preset == "512x512 (Ultra Fast)":
            width = height = 512
        elif size_preset == "768x768 (Fast)":
            width = height = 768
        else:
            width = height = 1024
        # Clamp the step count.
        steps = max(8, min(steps, 15))
        progress(0.1, desc="Initializing...")
        # Preload the model (does not count against GPU time).
        if not initialize_model():
            return None, "❌ Model init failed", "", None
        progress(0.2, desc="Generating (30-50s)...")
        # Call the GPU-decorated function.
        image, actual_seed = generate_image_fast(
            prompt, style, neg_prompt, steps, cfg, seed, width, height
        )
        progress(0.9, desc="Saving...")
        filename = f"IMG_{actual_seed}.png"
        filepath = os.path.join(SAVE_DIR, filename)
        image.save(filepath)
        metadata = f"""Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Prompt: {prompt}
Style: {style}
Seed: {actual_seed}
Steps: {steps} | CFG: {cfg}
Size: {width}x{height}
"""
        info = f"Seed: {actual_seed} | {width}×{height} | {steps} steps"
        progress(1.0, desc="Complete!")
        return image, info, metadata, image
    except Exception as e:
        cleanup_memory()
        error_msg = f"Generation failed: {str(e)[:100]}"
        print(f"❌ {error_msg}")
        return None, error_msg, "", None

# ===== UI =====
def create_interface():
    with gr.Blocks(title="Fast FLUX Generator") as interface:
        gr.HTML('<h1 style="text-align:center">⚡ Fast FLUX Generator</h1>')
        gr.HTML('''
        <div style="background:#fff3cd;padding:10px;border-radius:8px;margin:10px 0;">
            <strong>⚠️ ZeroGPU Limits:</strong><br>
            • 60 second GPU timeout (hard limit)<br>
            • Recommended: 512x512 or 768x768, 10-15 steps<br>
            • Keep prompts under 200 characters
        </div>
        ''')
        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt (keep it short!)",
                    placeholder="woman, portrait, detailed",
                    lines=4,
                    max_lines=4,
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality, blurry",
                    lines=2,
                )
            with gr.Column(scale=1):
                style_input = gr.Radio(
                    label="Style",
                    choices=list(STYLE_PRESETS.keys()),
                    value="Realistic",
                )
                seed_input = gr.Number(
                    label="Seed (-1 = random)",
                    value=-1,
                    precision=0,
                )
                size_preset = gr.Radio(
                    label="Size (smaller = faster)",
                    choices=[
                        "512x512 (Ultra Fast)",
                        "768x768 (Fast)",
                        "1024x1024 (Slow)",
                    ],
                    value="768x768 (Fast)",
                )
                steps_input = gr.Slider(
                    label="Steps (10-15 recommended)",
                    minimum=8,
                    maximum=15,
                    value=12,
                    step=1,
                )
                cfg_input = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                )
        generate_button = gr.Button(
            "🚀 GENERATE (30-50s)",
            variant="primary",
            size="lg",
        )
        image_output = gr.Image(label="Result", show_label=False)
        generation_info = gr.Textbox(
            label="Info",
            interactive=False,
            visible=True,
        )
        metadata_content = gr.Textbox(visible=False)
        current_image = gr.Image(visible=False)

        generate_button.click(
            fn=generate_wrapper,
            inputs=[
                prompt_input, style_input, negative_prompt_input,
                steps_input, cfg_input, seed_input, size_preset,
            ],
            outputs=[
                image_output, generation_info,
                metadata_content, current_image,
            ],
            show_progress="full",
        )
        prompt_input.submit(
            fn=generate_wrapper,
            inputs=[
                prompt_input, style_input, negative_prompt_input,
                steps_input, cfg_input, seed_input, size_preset,
            ],
            outputs=[
                image_output, generation_info,
                metadata_content, current_image,
            ],
            show_progress="full",
        )
    return interface

if __name__ == "__main__":
    print("🚀 Starting Fast FLUX Generator")
    print(f"🔧 Model: {FIXED_MODEL}")
    print(f"🔧 CUDA: {torch.cuda.is_available()}")
    # Preload the model so the first request skips initialization.
    print("📦 Pre-loading model...")
    initialize_model()
    app = create_interface()
    app.queue(max_size=3, default_concurrency_limit=1)
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
    )
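
# Local usage (sketch; the package list is an assumption, pin versions to
# whatever the Space actually uses):
#   pip install torch diffusers transformers gradio sentencepiece protobuf
#   python app.py
# The `import spaces` at the top simply fails outside Hugging Face Spaces,
# and the app falls back to plain CUDA/CPU execution via SPACES_AVAILABLE.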