File size: 5,910 Bytes
a5924d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import random
import tempfile
import time

import gradio as gr
import requests

# Hugging Face Inference API configuration, read from the environment
# (e.g. the Space's "Variables and secrets"). The key may be absent (None);
# the endpoint falls back to the SDXL base model URL below.
_DEFAULT_MODEL_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

HF_API_KEY = os.environ.get("HF_API_KEY")
HF_MODEL_URL = os.environ.get("HF_MODEL_URL", _DEFAULT_MODEL_URL)

# Style presets offered in the UI: label -> keywords that are appended to the
# user's prompt, joined with ", ". An empty keyword tuple means "no style".
_STYLE_KEYWORDS = (
    ("Neutral / Keine Vorgabe", ()),
    ("Fotorealistisch", (
        "photorealistic", "8k", "ultra-detailed", "natural lighting",
        "shallow depth of field",
    )),
    ("Cinematic", (
        "cinematic", "moody lighting", "35mm film", "bokeh", "high contrast",
        "dramatic shadows",
    )),
    ("Digital Art", (
        "digital art", "concept art", "trending on artstation",
        "highly detailed", "sharp focus",
    )),
    ("Anime / Manga", (
        "anime", "clean lines", "smooth shading", "studio quality",
        "vibrant colors",
    )),
    ("Stylized Painting", (
        "oil painting", "impasto", "rich texture", "brush strokes",
        "chiaroscuro",
    )),
    ("Product Shot", (
        "studio product photo", "softbox lighting", "seamless background",
        "crisp details",
    )),
    ("Architecture", (
        "architectural visualization", "photoreal", "ultra-wide",
        "global illumination",
    )),
)

STYLES = {label: ", ".join(terms) for label, terms in _STYLE_KEYWORDS}

# Output formats offered in the UI: label -> (width, height, aspect ratio).
# Each label is derived from its numbers, so label and size cannot drift apart.
_SIZE_PRESETS = (
    # (orientation, width, height, ratio)
    ("Quadrat", 768, 768, "1:1"),
    ("Quer", 1024, 576, "16:9"),
    ("Quer", 960, 640, "3:2"),
    ("Quer", 896, 672, "4:3"),
    ("Hoch", 576, 1024, "9:16"),
)

SIZES = {
    f"{orientation} {ratio} ({w}x{h})": (w, h, ratio)
    for orientation, w, h, ratio in _SIZE_PRESETS
}

def build_prompt(prompt, style, negative, styles=None):
    """Compose the positive text prompt sent to the image model.

    The user's prompt is stripped. If ``style`` names a preset with a non-empty
    keyword string, those keywords are appended. Non-empty parts are joined
    with ``", "``.

    ``negative`` is accepted for backward compatibility but is deliberately
    NOT merged into the result. Text-to-image models treat every word of the
    positive prompt as something to depict, so appending
    "Negative: (blurry, watermark)" would steer the image *towards* those
    artefacts. The negative prompt is transmitted separately, as the
    ``negative_prompt`` request parameter.

    Args:
        prompt: Free-text user prompt; ``None`` is treated as empty.
        style: Key into the style presets; unknown or empty keys add nothing.
        negative: Ignored (see above).
        styles: Optional mapping of style label -> keyword string. Defaults
            to the module-level ``STYLES`` table.

    Returns:
        The comma-joined prompt string (possibly empty).
    """
    presets = STYLES if styles is None else styles
    parts = [(prompt or "").strip()]
    if style:
        parts.append(presets.get(style) or "")
    # Drop empty pieces so no dangling ", " separators appear.
    return ", ".join(part for part in parts if part)

def _resolve_seed(seed_opt):
    """Turn the seed field's value into a concrete integer seed.

    ``None``, a blank string or the keyword ``"zufällig"`` (any casing,
    surrounding whitespace ignored) select a random seed in
    ``[0, 2**31 - 1]``. Numbers are used as-is. Any other text must parse
    as an integer.

    Raises:
        gr.Error: if the text is neither the random keyword nor an integer.
    """
    if isinstance(seed_opt, (int, float)):
        return int(seed_opt)
    text = "" if seed_opt is None else str(seed_opt).strip()
    if text.casefold() in ("", "zufällig"):
        return random.randint(0, 2**31 - 1)
    try:
        return int(text)
    except ValueError:
        # Surface a readable UI message instead of a raw traceback.
        raise gr.Error(
            f"Ungültiger Seed {text!r}: bitte eine ganze Zahl oder 'zufällig' eingeben."
        ) from None


def _error_detail(resp):
    """Return the response body for error messages: parsed JSON if possible, else raw text."""
    try:
        return resp.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return resp.text


def generate(prompt, negative, style, size_label, steps, guidance, seed_opt, safety):
    """Request one image from the configured Hugging Face inference endpoint.

    Args:
        prompt: User prompt; must have at least 3 characters after stripping.
        negative: Optional negative prompt, sent as ``negative_prompt``.
        style: Key of ``STYLES``; its keywords are appended via ``build_prompt``.
        size_label: Key of ``SIZES`` selecting width and height.
        steps: Number of inference steps.
        guidance: Guidance scale.
        seed_opt: Seed field value; ``None``, blank or ``"zufällig"`` picks a
            random seed.
        safety: Forwarded as the ``safety_checker`` parameter. Per the UI
            label, the model may not support it.

    Returns:
        Tuple ``(image_bytes, meta_text)``: the raw response body and a
        one-line summary of seed, steps, guidance, size and duration.

    Raises:
        gr.Error: on missing API key, too-short prompt, invalid seed, network
            failure, non-200 status, or a JSON/text body instead of an image.
    """
    if not HF_API_KEY:
        raise gr.Error("HF_API_KEY fehlt in den Space-Settings (Variables and secrets).")
    if not prompt or len(prompt.strip()) < 3:
        raise gr.Error("Bitte einen aussagekräftigen Prompt eingeben (mind. 3 Zeichen).")

    width, height, _ = SIZES[size_label]
    seed = _resolve_seed(seed_opt)

    body = {
        "inputs": build_prompt(prompt, style, negative),
        "parameters": {
            "negative_prompt": (negative or "").strip(),
            "guidance_scale": float(guidance),
            "num_inference_steps": int(steps),
            "width": int(width),
            "height": int(height),
            "seed": int(seed),
            "safety_checker": bool(safety),
        },
    }
    headers = {"Authorization": f"Bearer {HF_API_KEY}"}

    # perf_counter is monotonic, so the measured duration is not affected by clock changes.
    t0 = time.perf_counter()
    try:
        resp = requests.post(HF_MODEL_URL, headers=headers, json=body, timeout=600)
    except requests.exceptions.RequestException as err:
        # Timeouts/connection errors would otherwise surface as an opaque traceback.
        raise gr.Error(f"Inference-Dienst nicht erreichbar: {err}") from err
    dt = time.perf_counter() - t0

    if resp.status_code != 200:
        raise gr.Error(f"Fehler vom Inference-Dienst: {_error_detail(resp)}")

    # A 200 whose body is JSON/text is not an image; don't hand it to the image widgets.
    content_type = resp.headers.get("content-type", "")
    if content_type.startswith(("application/json", "text/")):
        raise gr.Error(f"Fehler vom Inference-Dienst: {_error_detail(resp)}")

    return resp.content, f"Seed: {seed} • Schritte: {steps} • Guidance: {guidance} • Größe: {width}x{height} • Dauer: {dt:.1f}s"

def _persist_image(img_bytes):
    """Write raw image bytes to a temporary file and return the file's path.

    ``generate`` returns the API response as raw ``bytes``. Gradio's ``Image``,
    ``Gallery`` and ``DownloadButton`` outputs accept file paths, not bytes.
    Writing the payload to disk bridges the two. The extension is chosen from
    the payload's magic bytes, so the download gets a sensible file name.
    """
    if img_bytes[:3] == b"\xff\xd8\xff":
        suffix = ".jpg"
    elif img_bytes[:4] == b"RIFF" and img_bytes[8:12] == b"WEBP":
        suffix = ".webp"
    else:
        suffix = ".png"
    # delete=False: Gradio reads the file after this function has returned.
    with tempfile.NamedTemporaryFile(prefix="bild_", suffix=suffix, delete=False) as fh:
        fh.write(img_bytes)
        return fh.name


with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet"), css="""
.gradio-container { max-width: 1060px !important; margin: 0 auto; }
.meta { opacity:.7; font-size: .9em }
""") as demo:
    gr.Markdown("# Promtkatze – KI-Bildgenerator 🐾")

    # Two columns: generation controls on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", placeholder="z. B. Ultra-realistische Katze im Astronautenanzug, golden hour, 50mm", lines=3)
            negative = gr.Textbox(label="Negative Prompt (optional)", placeholder="blurry, low quality, watermark, distorted anatomy", lines=2)
            style = gr.Dropdown(list(STYLES.keys()), value="Neutral / Keine Vorgabe", label="Stil")
            size = gr.Dropdown(list(SIZES.keys()), value="Quadrat 1:1 (768x768)", label="Bildgröße / Seitenverhältnis")
            with gr.Accordion("Fortgeschritten", open=False):
                steps = gr.Slider(10, 60, value=30, step=1, label="Schritte (Qualität vs. Zeit)")
                guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.5, label="Guidance (Kohärenz)")
                seed = gr.Textbox(value="zufällig", label="Seed (Zahl oder 'zufällig')")
                safety = gr.Checkbox(value=True, label="Sicherheitsfilter aktivieren (falls vom Modell unterstützt)")
            run = gr.Button("Bild generieren", variant="primary")
        with gr.Column(scale=4):
            out = gr.Image(label="Ergebnis", type="pil")
            meta = gr.Markdown("", elem_classes=["meta"])
            gallery = gr.Gallery(label="Verlauf (Session)", show_label=True, columns=[4], height="auto")
            # Hidden until the first image exists.
            dl = gr.DownloadButton(label="Letztes Bild herunterladen", visible=False)

    # Example rows fill the inputs in the same order as the `inputs` list.
    examples = gr.Examples(
        examples=[
            ["Cinematic portrait of a medieval knight, dramatic rim lighting, volumetric fog", "blurry, deformed, extra fingers", "Cinematic", "Quer 3:2 (960x640)", 30, 7.0, "zufällig", True],
            ["Photorealistic macro shot of a dew-covered rose, 100mm, ultra-detailed", "lowres, jpeg artifacts, oversaturated", "Fotorealistisch", "Quadrat 1:1 (768x768)", 32, 7.5, "zufällig", True],
            ["Anime hero in a neon city at night, rain, dynamic pose, vibrant", "noisy, grayscale, watermark", "Anime / Manga", "Hoch 9:16 (576x1024)", 28, 8.0, "zufällig", True]
        ],
        inputs=[prompt, negative, style, size, steps, guidance, seed, safety]
    )

    # Per-session list of generated image file paths; backs the history gallery.
    state_hist = gr.State([])

    def _run(p, n, s, sz, st, g, sd, safe, hist):
        """Click handler: generate one image and append it to the session history.

        Returns values for the outputs (out, meta, gallery, state_hist, dl).
        The history is returned twice: once to render the gallery and once
        to store it back into ``state_hist``. Otherwise it would never
        accumulate across clicks.
        """
        img_bytes, meta_txt = generate(p, n, s, sz, st, g, sd, safe)
        img_path = _persist_image(img_bytes)
        # Build a new list rather than mutating the state's initial value.
        new_hist = (hist or []) + [img_path]
        return img_path, meta_txt, new_hist, new_hist, gr.update(visible=True, value=img_path)

    run.click(_run, [prompt, negative, style, size, steps, guidance, seed, safety, state_hist],
              [out, meta, gallery, state_hist, dl], queue=True)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))