File size: 3,151 Bytes
c27dc1b
 
 
 
 
da5103b
 
a245c5a
da5103b
 
 
 
1974c95
 
da5103b
c27dc1b
1974c95
da5103b
 
1974c95
 
 
da5103b
c27dc1b
da5103b
 
 
c27dc1b
da5103b
1974c95
 
 
da5103b
1974c95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27dc1b
da5103b
 
 
 
1974c95
da5103b
c27dc1b
da5103b
 
c27dc1b
1974c95
da5103b
 
 
 
c27dc1b
 
da5103b
 
a245c5a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# 修复 huggingface_hub 缺失 cached_download 问题(必须在导入 diffusers 之前)
import huggingface_hub
from huggingface_hub import hf_hub_download
huggingface_hub.cached_download = hf_hub_download  # monkey-patch BEFORE importing diffusers

import gradio as gr
import pandas as pd
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
import os
import zipfile
import tempfile
import uuid

# ==== Model loading ====
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# On CUDA request the "fp16" weight variant in float16; on CPU use float32
# and no weight variant.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16" if device == "cuda" else None,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)

# ==== Caption text ====
def generate_text(prompt):
    """Return the Chinese caption text that accompanies the image for *prompt*."""
    caption_template = "这是一幅描绘:{} 的画面,小朋友可以根据图像发挥想象力哦!"
    return caption_template.format(prompt)

# ==== Batch CSV processing ====
def _upload_path(file):
    """Return the local path of a ``gr.File`` upload, or ``None`` if nothing was uploaded.

    Depending on the Gradio version / ``type`` setting, the handler may receive
    either a plain path (str / PathLike) or a tempfile-like object exposing
    ``.name``; both are accepted.
    """
    if file is None:
        return None
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def process_csv(file):
    """Generate one image and one caption per CSV row and bundle them into a zip.

    Args:
        file: The upload from the ``gr.File`` component. The CSV must contain a
            ``prompt`` column.

    Returns:
        Path to a zip archive holding ``image_<n>.png`` and ``text_<n>.txt``
        for every row with a non-blank prompt, where ``<n>`` is the row's
        index + 1. The archive lives in a directory that is not deleted when
        this function returns, so Gradio can still read it afterwards.

    Raises:
        gr.Error: if no ``.csv`` file was uploaded, the ``prompt`` column is
            missing, or no row contains a usable prompt.
    """
    path = _upload_path(file)
    # Case-insensitive so files named e.g. "data.CSV" are accepted as well.
    if path is None or not path.lower().endswith(".csv"):
        raise gr.Error("请上传一个有效的 CSV 文件(包含 prompt 列)")

    # "utf-8-sig" reads plain UTF-8 too, but also strips the BOM Excel writes;
    # without it the first header becomes "\ufeffprompt" and the check fails.
    df = pd.read_csv(path, encoding="utf-8-sig")
    if "prompt" not in df.columns:
        raise gr.Error("CSV 文件中必须包含名为 'prompt' 的列")

    # Empty cells are read as NaN; skip them (and whitespace-only prompts)
    # rather than sending them to the pipeline.
    rows = [
        (idx, str(value).strip())
        for idx, value in df["prompt"].items()
        if not pd.isna(value) and str(value).strip()
    ]
    if not rows:
        raise gr.Error("CSV 文件中没有有效的 prompt")

    # The zip must outlive this call because Gradio reads the returned path
    # after the handler returns; mkdtemp() is not cleaned up automatically,
    # unlike a TemporaryDirectory context.
    zip_dir = tempfile.mkdtemp()
    zip_path = os.path.join(zip_dir, f"output_{uuid.uuid4().hex}.zip")

    # Images are written to a scratch directory first and copied into the
    # archive; the scratch directory is removed once the zip is closed.
    with tempfile.TemporaryDirectory() as work_dir:
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
            for idx, prompt in rows:
                image = pipe(prompt=prompt).images[0]
                image_name = f"image_{idx+1}.png"
                image_path = os.path.join(work_dir, image_name)
                image.save(image_path)
                zipf.write(image_path, arcname=image_name)

                # writestr encodes str data as UTF-8, same as the old text file.
                zipf.writestr(f"text_{idx+1}.txt", generate_text(prompt))

    return zip_path

# ==== Single-image generation ====
def text_to_image(prompt):
    """Generate a single image for *prompt* with the shared pipeline.

    Args:
        prompt: Text from the prompt textbox.

    Returns:
        The first entry of ``pipe(prompt=...).images``; the UI shows it in a
        ``type="pil"`` image component.

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    # Reject blank input up front instead of spending a full diffusion run on
    # an empty prompt (the batch path skips blank prompts the same way).
    if not prompt or not str(prompt).strip():
        raise gr.Error("请输入提示词")
    return pipe(prompt=prompt).images[0]

# ==== Gradio UI ====
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 AI 图文生成器(批量和单图)")

    # Tab 1: batch generation from an uploaded CSV, returned as a zip download.
    with gr.Tab("📂 批量生成(上传CSV)"):
        csv_input = gr.File(
            label="上传CSV文件(需包含 'prompt' 列)",
            file_types=[".csv"],
        )
        csv_output = gr.File(label="下载生成的图文ZIP包")
        csv_btn = gr.Button("开始生成")

    # Tab 2: a single prompt rendered straight to an image.
    with gr.Tab("🖼️ 单图生成(文本转图片)"):
        prompt_input = gr.Textbox(
            label="输入提示词",
            placeholder="如:一只飞翔在太空的小猫咪",
        )
        image_output = gr.Image(label="生成的图像", type="pil")
        single_btn = gr.Button("立即生成")

    # Event wiring: registered inside the Blocks context, in the same order
    # as the tabs above.
    csv_btn.click(fn=process_csv, inputs=csv_input, outputs=csv_output)
    single_btn.click(fn=text_to_image, inputs=prompt_input, outputs=image_output)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()