meme-generator / app.py
boheng.xie
change main file
b733220
# app.py
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image, ImageFilter
# Detect the compute device for this environment (e.g. a Hugging Face Space).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# Load SD-Turbo. Half precision and the fp16 weight variant are used only on
# CUDA; the CPU path stays in float32.
model = StableDiffusionPipeline.from_pretrained(
    "stabilityai/sd-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
    safety_checker=None,
    variant="fp16" if device == "cuda" else None,
    low_cpu_mem_usage=True,
)

# Performance tweaks / device placement.
if device == "cuda":
    # xformers is an optional extra: enabling it can raise (e.g. when the
    # package is missing or its kernels don't work on this GPU). It is a pure
    # optimization, so fall back to the default attention instead of crashing
    # the whole app at startup.
    try:
        model.enable_xformers_memory_efficient_attention()
    except (ImportError, RuntimeError, ValueError) as err:
        print(f"xformers 不可用,使用默认注意力实现: {err}")
    # Model CPU offload manages GPU placement itself via hooks (it first moves
    # the pipeline to CPU), so the pipeline is deliberately NOT moved to CUDA
    # with .to() beforehand — that would only upload the weights and then
    # immediately move them back.
    model.enable_model_cpu_offload()
else:
    model = model.to(device)
def validate_input(prompt):
    """Validate a user prompt and return it normalized.

    Surrounding whitespace is stripped *before* any check. A blank or
    whitespace-only prompt is therefore reported as empty, and padding does
    not count toward the length limit.

    Args:
        prompt: Raw text from the input box. ``None`` is treated as empty.

    Returns:
        The prompt with leading/trailing whitespace removed.

    Raises:
        ValueError: If the prompt is empty, longer than 100 characters after
            stripping, or contains no character in U+4E00..U+9FFF.
    """
    text = (prompt or "").strip()
    if not text:
        raise ValueError("请输入有效描述")
    if len(text) > 100:
        raise ValueError("提示词过长(最多100字符)")
    if not any('\u4e00' <= ch <= '\u9fff' for ch in text):
        raise ValueError("请至少包含一个中文字符")
    return text
def post_process(image):
    """Sharpen a generated image before it is displayed.

    Applies ``ImageFilter.SHARPEN`` first, then
    ``ImageFilter.UnsharpMask(radius=2, percent=150)`` on that result, and
    returns the new image. The input image object is not modified in place.
    """
    sharpened = image.filter(ImageFilter.SHARPEN)
    unsharp_mask = ImageFilter.UnsharpMask(radius=2, percent=150)
    return sharpened.filter(unsharp_mask)
def generate(prompt):
    """Turn a user prompt into a meme image plus a status message for the UI.

    Args:
        prompt: Text from the input box; checked by ``validate_input``.

    Returns:
        ``(image, status)``. On success this is the post-processed image and a
        success message. On failure it is ``(None, "❌ 错误:<message>")``.
        Errors are returned rather than raised so they appear in the status
        box.
    """
    try:
        valid_prompt = validate_input(prompt)
    except ValueError as e:
        # Bad user input: just show the validation message; nothing to log.
        return None, f"❌ 错误:{str(e)}"

    on_gpu = device == "cuda"
    # NOTE(review): SD-Turbo's model card describes 512x512 generation; the
    # 768px GPU size is kept as-is but should be confirmed as intentional.
    size = 768 if on_gpu else 512
    try:
        image = model(
            valid_prompt,
            num_inference_steps=4 if on_gpu else 15,
            # SD-Turbo was trained without classifier-free guidance and its
            # model card says to pass guidance_scale=0.0. A value > 1 turns
            # CFG on, which doubles UNet work per step and degrades output.
            guidance_scale=0.0,
            height=size,
            width=size,
        ).images[0]
        result = post_process(image)
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback
        # server-side instead of silently swallowing the failure.
        import logging
        logging.getLogger(__name__).exception("图像生成失败")
        return None, f"❌ 错误:{str(e)}"
    return result, "🎉 生成成功!点击图片可下载"
# Build the UI.
# Example prompts shown under the interface; each inner list is one row of
# values for `input_box`.
_EXAMPLE_PROMPTS = [["熊猫头说'我太难了'"], ["流泪猫猫头配文'真的栓Q'"]]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🎭 AI表情包工坊")

    # First row: prompt input next to the trigger button.
    with gr.Row():
        input_box = gr.Textbox(label="输入描述", placeholder="例如:打工人的周一早晨...")
        generate_btn = gr.Button("生成", variant="primary")

    # Second row: generated image alongside a read-only status line.
    with gr.Row():
        image_out = gr.Image(label="生成结果", show_label=False, type="pil")
        status_box = gr.Textbox(label="状态", interactive=False)

    # Examples are wired only to `input_box` (no fn/outputs given).
    gr.Examples(examples=_EXAMPLE_PROMPTS, inputs=input_box)

    # generate() returns (image, status), matching the order of `outputs`.
    generate_btn.click(
        fn=generate,
        inputs=input_box,
        outputs=[image_out, status_box],
    )
if __name__ == "__main__":
    # Bind to 0.0.0.0 (all network interfaces) rather than only localhost, so
    # the server is reachable from outside the host/container.
    launch_options = {"server_name": "0.0.0.0"}
    demo.launch(**launch_options)