# Source: Hugging Face Space file "app.py"
# Last commit: "Update app.py" (c50111a, verified) by alexander00001 — 13.3 kB
# Optional dependency: `spaces` may be missing outside Hugging Face Spaces.
# SPACES_AVAILABLE records whether the import succeeded.
try:
    import spaces
except ImportError:
    SPACES_AVAILABLE = False
    print("⚠️ Spaces not available - running in regular mode")
else:
    SPACES_AVAILABLE = True
    print("✅ Spaces available - ZeroGPU mode")
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import datetime
import io
import json
import os
from typing import Optional
# ======================
# Configuration (edit only this section to extend the app)
# ======================

# 1. Base model
BASE_MODEL = "SG161222/RealisticVisionV6.0"

# 2. Fixed LoRAs (not user-selectable, loaded automatically),
#    given as (source, adapter weight) pairs.
FIXED_LORAS = [
    ("Lykon/epiCRealism_LoRA", 0.8),  # quality enhancement
    ("latent-consistency/lora-dreamshaper", 0.7),  # pose control
]

# 3. Style templates (prepended to the user's prompt)
STYLE_PROMPTS = {
    "None": "",
    "Realistic": (
        "photorealistic, ultra-detailed skin, natural lighting, 8k, "
        "professional photography, f/1.8, shallow depth of field, Canon EOS R5, "
    ),
    "Anime": (
        "anime style, cel shading, vibrant colors, detailed eyes, "
        "studio ghibli, trending on pixiv, "
    ),
    "Comic": (
        "comic book style, bold outlines, dynamic angles, comic panel, "
        "Marvel style, inked lines, "
    ),
    "Watercolor": (
        "watercolor painting, soft brush strokes, translucent layers, "
        "artistic, painterly, paper texture, "
    ),
}
# 4. Optional LoRA dropdown entries (the user picks one; "None" clears it).
#    Every entry except "None" is written as "<display name>: <source>".
OPTIONAL_LORAS = [
    "None",
    "Add Detail: https://huggingface.co/latent-consistency/lora-add-detail",
    "Vintage Photo: https://huggingface.co/ckpt/LoRA-vintage-photo",
    "Cinematic: https://huggingface.co/latent-consistency/lora-cinematic",
    "Portrait Enhancer: https://huggingface.co/deforum/Portrait-Enhancer-LoRA",
    "Soft Focus: https://huggingface.co/latent-consistency/lora-soft-focus",
]

# Map each display name to its source; the "None" entry maps to None.
# Splitting only on the first ": " keeps the "https://" part of the source intact.
OPTIONAL_LORA_MAP = dict(
    ("None", None) if entry == "None" else tuple(entry.split(": ", 1))
    for entry in OPTIONAL_LORAS
)
# Default parameter values: initial values of the UI controls and the targets
# of the "Reset" buttons.
DEFAULT_SEED = -1  # -1 means "pick a random seed"
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_LORA_SCALE = 0.8  # weight applied to the optional LoRA
DEFAULT_STEPS = 30  # number of inference steps
DEFAULT_CFG = 7.5  # guidance (CFG) scale
# ======================
# Global state: the model is loaded lazily
# ======================
# Shared pipeline instance; stays None until load_pipeline() creates it.
pipe = None

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_pipeline():
    """Return the shared Stable Diffusion pipeline, creating it on first use.

    The pipeline is built from ``BASE_MODEL`` with the safety checker
    disabled and is cached in the module-level ``pipe`` variable, so later
    calls return the same object without reloading.

    Returns:
        The cached ``StableDiffusionPipeline`` instance.
    """
    global pipe
    if pipe is not None:
        return pipe

    print("🚀 Loading base model...")
    on_gpu = device == "cuda"
    loaded = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        # Half precision is only safe on the GPU; many CPU kernels do not
        # implement float16, so fall back to float32 there.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )
    loaded.enable_attention_slicing()
    loaded.enable_vae_slicing()
    if on_gpu:
        # Model CPU offload moves each sub-model to the GPU on demand, so the
        # pipeline must not be moved to CUDA up front (intended for ZeroGPU).
        loaded.enable_model_cpu_offload()
    else:
        loaded.to(device)

    # Publish the pipeline only once it is fully configured, so a failure
    # above never leaves a half-initialised object cached in `pipe`.
    pipe = loaded
    print("✅ Base model loaded.")
    return pipe
def unload_pipeline():
    """Drop the cached pipeline, if any, and release cached CUDA memory.

    Does nothing when no pipeline is currently loaded.
    """
    global pipe
    if pipe is None:
        return

    # Remove the module-level reference before emptying the CUDA cache, then
    # restore the "not loaded" marker.
    del pipe
    torch.cuda.empty_cache()
    pipe = None
    print("🗑️ Pipeline unloaded.")
# ======================
# Main generation function
# ======================
_HF_URL_PREFIX = "https://huggingface.co/"


def _lora_source(ref):
    """Return ``ref`` as a Hub repo id, dropping a leading huggingface.co URL prefix."""
    if ref.startswith(_HF_URL_PREFIX):
        return ref[len(_HF_URL_PREFIX):].strip("/")
    return ref


def _ensure_adapter(pipeline, source, adapter_name):
    """Load the LoRA ``source`` under ``adapter_name`` unless it is already attached."""
    attached = set()
    for names in pipeline.get_list_adapters().values():
        attached.update(names)
    if adapter_name not in attached:
        pipeline.load_lora_weights(source, adapter_name=adapter_name)


def generate_image(
    prompt, negative_prompt, style, seed, width, height, optional_lora_name, lora_scale,
    steps, cfg_scale
):
    """Generate one image and package it together with its metadata.

    NOTE(review): ``spaces`` is imported at the top of the file but no
    ``spaces.GPU`` decorator is applied here; confirm whether ZeroGPU
    deployments need one around this function.

    Args:
        prompt: Positive prompt text; the selected style template is prepended.
        negative_prompt: Negative prompt text, passed through unchanged.
        style: Key into ``STYLE_PROMPTS``; unknown or empty means no template.
        seed: Integer seed; ``-1`` requests a random one.
        width: Output width in pixels.
        height: Output height in pixels.
        optional_lora_name: Key into ``OPTIONAL_LORA_MAP``; ``"None"`` (or an
            empty selection) disables the optional LoRA.
        lora_scale: Adapter weight for the optional LoRA.
        steps: Number of inference steps.
        cfg_scale: Guidance scale.

    Returns:
        A 6-tuple ``(image, metadata_json, image_filename, metadata_filename,
        image_bytes, metadata_bytes)``: the image returned by the pipeline,
        the metadata as a JSON string, suggested file names, the image
        encoded as WebP bytes, and the metadata JSON encoded as UTF-8 bytes.
    """
    global pipe
    # Lazy-load the model.
    pipe = load_pipeline()

    # UI sliders may deliver numbers as floats; the generator and the
    # pipeline expect integers for these values.
    seed = int(seed)
    width = int(width)
    height = int(height)
    steps = int(steps)

    # Resolve the seed.
    if seed == -1:
        seed = torch.randint(0, 2**32, (1,)).item()
    generator = torch.Generator(device=device).manual_seed(seed)

    # Prepend the style template to the user's prompt.
    full_prompt = STYLE_PROMPTS.get(style, "") + (prompt or "")
    full_negative_prompt = negative_prompt

    # Fixed LoRAs are always active. Each adapter is loaded only once per
    # pipeline: loading a second adapter under an existing name is rejected,
    # which previously broke every generation after the first.
    adapter_names = [lora_id for lora_id, _ in FIXED_LORAS]
    adapter_weights = [scale for _, scale in FIXED_LORAS]
    for lora_id in adapter_names:
        _ensure_adapter(pipe, lora_id, lora_id)

    # Optional LoRA: activate it next to the fixed ones when one is selected.
    use_optional = bool(optional_lora_name) and optional_lora_name != "None"
    if use_optional:
        _ensure_adapter(
            pipe,
            _lora_source(OPTIONAL_LORA_MAP[optional_lora_name]),
            optional_lora_name,
        )
        adapter_names.append(optional_lora_name)
        adapter_weights.append(lora_scale)

    # Activating exactly this list also deactivates any optional LoRA that an
    # earlier request loaded but this one did not select.
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    # Generate the image.
    image = pipe(
        prompt=full_prompt,
        negative_prompt=full_negative_prompt,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # One clock reading feeds both the metadata and the file name.
    now = datetime.datetime.now()

    # Build the metadata.
    metadata = {
        "prompt": full_prompt,
        "negative_prompt": full_negative_prompt,
        "base_model": BASE_MODEL,
        "fixed_loras": [lora_id for lora_id, _ in FIXED_LORAS],
        "optional_lora": optional_lora_name if use_optional else None,
        "lora_scale": lora_scale,
        "seed": seed,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "style": style,
        "width": width,
        "height": height,
        "timestamp": now.isoformat()
    }
    metadata_json = json.dumps(metadata, indent=2, ensure_ascii=False)

    # Build the file name stem.
    filename_base = f"{seed}-{now.strftime('%y%m%d%H%M')}"

    # Encode the image as high-quality WebP.
    img_buffer = io.BytesIO()
    image.save(img_buffer, format="WEBP", quality=95, method=6)

    # Return: image, metadata, file names, and the encoded payloads.
    return (
        image,
        metadata_json,
        f"{filename_base}.webp",
        f"{filename_base}.txt",
        img_buffer.getvalue(),
        metadata_json.encode('utf-8')
    )
# ======================
# Gradio UI
# ======================
def _generate_for_ui(*ui_values):
    """Run ``generate_image`` and convert its result into Gradio outputs.

    ``generate_image`` returns file names and raw bytes for the image and
    the metadata. Download components need real file paths, so both payloads
    are written into a fresh temporary directory first.

    Returns:
        ``(image, metadata_json, image_path, metadata_path)``.
    """
    import tempfile

    image, metadata_json, img_name, meta_name, img_bytes, meta_bytes = generate_image(
        *ui_values
    )
    out_dir = tempfile.mkdtemp(prefix="generated_")
    img_path = os.path.join(out_dir, os.path.basename(img_name))
    meta_path = os.path.join(out_dir, os.path.basename(meta_name))
    with open(img_path, "wb") as img_file:
        img_file.write(img_bytes)
    with open(meta_path, "wb") as meta_file:
        meta_file.write(meta_bytes)
    return image, metadata_json, img_path, meta_path


with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="green",
        neutral_hue="slate",
        # text_size is a theme constructor argument, not a `.set()` variable.
        text_size="lg",
    ).set(
        body_background_fill="linear-gradient(135deg, #1e40af, #059669)",
        button_primary_background_fill="white",
        button_primary_text_color="#1e40af",
        input_background_fill="rgba(255,255,255,0.9)",
    ),
    css="""
body { font-family: 'Helvetica Neue', 'Segoe UI', 'Arial', sans-serif; }
.gr-button { font-family: 'Helvetica Neue', 'Arial', sans-serif; font-weight: 500; }
.gr-textarea { font-family: 'Consolas', 'Monaco', 'Courier New', monospace; }
""",
) as demo:
    gr.Markdown(
        """
# 🎨 AI Photo Generator (RealisticVision + LoRA)
**PRO + ZeroGPU Optimized | Multi-LoRA | Style Templates | Metadata Export**
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            # a. Prompt input
            prompt_input = gr.Textbox(
                label="Prompt (Positive)",
                placeholder="A beautiful woman, golden hour, soft sunlight...",
                lines=5,
                max_lines=20,
                elem_classes=["gr-textarea"]
            )
            # b. Negative prompt input
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                placeholder="blurry, low quality, deformed, cartoon, anime, text, watermark...",
                lines=5,
                max_lines=20,
                elem_classes=["gr-textarea"]
            )
            # c. Style selection (single choice)
            style_radio = gr.Radio(
                choices=list(STYLE_PROMPTS.keys()),
                label="Style",
                value="Realistic",
                elem_classes=["gr-radio"]
            )
            # d. Seed
            with gr.Row():
                seed_input = gr.Slider(
                    minimum=-1,
                    maximum=99999999,
                    step=1,
                    value=DEFAULT_SEED,
                    label="Seed (-1 = Random)"
                )
                seed_reset = gr.Button("Reset Seed")
            # e. Width
            with gr.Row():
                width_input = gr.Slider(
                    minimum=512,
                    maximum=1536,
                    step=64,
                    value=DEFAULT_WIDTH,
                    label="Width"
                )
                width_reset = gr.Button("Reset Width")
            # f. Height
            with gr.Row():
                height_input = gr.Slider(
                    minimum=512,
                    maximum=1536,
                    step=64,
                    value=DEFAULT_HEIGHT,
                    label="Height"
                )
                height_reset = gr.Button("Reset Height")
            # g. Optional LoRA (dropdown)
            optional_lora_dropdown = gr.Dropdown(
                choices=list(OPTIONAL_LORA_MAP.keys()),
                label="Optional LoRA",
                value="None",
                elem_classes=["gr-dropdown"]
            )
            # h. LoRA weight
            with gr.Row():
                lora_scale_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=DEFAULT_LORA_SCALE,
                    label="LoRA Scale"
                )
                lora_reset = gr.Button("Reset LoRA Scale")
            # i. Generation controls (Steps & CFG)
            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=DEFAULT_STEPS,
                    label="Steps"
                )
                cfg_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=DEFAULT_CFG,
                    label="CFG Scale"
                )
                gen_reset = gr.Button("Reset Generation")
            # m. Generate button
            generate_btn = gr.Button("✨ Generate Image", variant="primary", size="lg")
        with gr.Column(scale=2):
            # j. Image display
            image_output = gr.Image(label="Generated Image", height=768, format="webp")
            # k. Metadata display
            metadata_output = gr.Textbox(
                label="Metadata (JSON)",
                lines=12,
                max_lines=20,
                elem_classes=["gr-textarea"]
            )
            # l. Download buttons (side by side). Each generation sets their
            #    value to the file written for that result.
            with gr.Row():
                download_img_btn = gr.DownloadButton("⬇️ Download Image (WebP)")
                download_meta_btn = gr.DownloadButton("⬇️ Download Metadata (TXT)")

    # ======================
    # Event bindings
    # ======================
    # Reset seed
    seed_reset.click(fn=lambda: DEFAULT_SEED, outputs=seed_input)
    # Reset width
    width_reset.click(fn=lambda: DEFAULT_WIDTH, outputs=width_input)
    # Reset height
    height_reset.click(fn=lambda: DEFAULT_HEIGHT, outputs=height_input)
    # Reset LoRA scale
    lora_reset.click(fn=lambda: DEFAULT_LORA_SCALE, outputs=lora_scale_slider)
    # Reset generation parameters
    gen_reset.click(
        fn=lambda: (DEFAULT_STEPS, DEFAULT_CFG),
        outputs=[steps_slider, cfg_slider]
    )
    # Generate: one output per returned value. The download buttons receive
    # the paths of the files written for this result.
    generate_btn.click(
        fn=_generate_for_ui,
        inputs=[
            prompt_input, negative_prompt_input, style_radio,
            seed_input, width_input, height_input,
            optional_lora_dropdown, lora_scale_slider,
            steps_slider, cfg_slider
        ],
        outputs=[
            image_output, metadata_output,
            download_img_btn, download_meta_btn
        ]
    )
# ======================
# Entry point
# ======================
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
# End of app.py