cocoat commited on
Commit
44e8a3c
·
verified ·
1 Parent(s): f6f6634

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -56
app.py CHANGED
@@ -1,66 +1,129 @@
1
  import gradio as gr
2
- import time
 
 
 
 
 
 
 
 
3
 
4
- def generate(prompt):
5
- time.sleep(3)
6
- return ["/file/example1.png", "/file/example2.png"]
 
 
7
 
8
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
9
- with gr.Row():
10
- with gr.Column(scale=1, min_width=200):
11
- prompt = gr.Textbox(label="Prompt")
12
- generate_button = gr.Button("Generate")
13
- with gr.Column(scale=2):
14
- gallery = gr.Gallery(label="Generated Images").style(grid=2)
15
 
16
- # カスタムローディングスクリプトを読み込み時に仕込む
17
- demo.load(None, None, None, _js=r"""
18
- () => {
19
- const existing = document.getElementById("custom-loader");
20
- if (existing) existing.remove();
 
 
21
 
22
- const loader = document.createElement("div");
23
- loader.id = "custom-loader";
24
- loader.style.position = "fixed";
25
- loader.style.top = "20px";
26
- loader.style.right = "20px";
27
- loader.style.zIndex = 9999;
28
- loader.style.background = "rgba(255,255,255,0.95)";
29
- loader.style.border = "1px solid #ccc";
30
- loader.style.padding = "8px 12px";
31
- loader.style.borderRadius = "8px";
32
- loader.style.display = "flex";
33
- loader.style.alignItems = "center";
34
- loader.style.fontSize = "16px";
35
- loader.style.fontWeight = "bold";
36
- loader.style.color = "#734c36";
37
- loader.innerHTML = `<span id="progress-text"></span><img src="/icon.png" style="width:32px;height:32px;border-radius:50%;margin-left:8px;" />`;
38
 
39
- const target = document.querySelector("body");
40
- if (target) target.appendChild(loader);
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- const text = "in progress";
43
- let index = 0;
44
- const span = loader.querySelector("#progress-text");
45
- function animateText() {
46
- if (index <= text.length) {
47
- span.textContent = text.slice(0, index++);
48
- setTimeout(animateText, 100);
49
- }
50
- }
51
- animateText();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- // Gradioのデフォルトスピナーを完全に非表示にする
54
- const style = document.createElement("style");
55
- style.innerHTML = `
56
- .loading-wrap, .svelte-spinner, .loading {
57
- display: none !important;
58
- }
59
- `;
60
- document.head.appendChild(style);
61
- }
62
- """)
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- generate_button.click(fn=generate, inputs=prompt, outputs=gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- demo.launch()
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from diffusers import (
6
+ StableDiffusionXLPipeline,
7
+ EulerAncestralDiscreteScheduler,
8
+ DPMSolverMultistepScheduler
9
+ )
10
+ from huggingface_hub import hf_hub_download
11
 
12
+ # デバイスと型の設定
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
15
+ MAX_SEED = np.iinfo(np.int32).max
16
+ MAX_SIZE = 2048
17
 
18
+ # モデルファイルのダウンロード
19
+ model_path = hf_hub_download(
20
+ repo_id="cocoat/cocoamix",
21
+ filename="recocoamixXL3_coamixXL3.safetensors"
22
+ )
 
 
23
 
24
+ # パイプライン構築&スケジューラ設定
25
+ pipe = StableDiffusionXLPipeline.from_single_file(
26
+ model_path, torch_dtype=torch_dtype, use_safetensors=True
27
+ ).to(device)
28
+ euler_scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
29
+ dpm_scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
30
+ pipe.scheduler = euler_scheduler
31
 
32
+ # 生成履歴保持
33
+ history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # HTMLテーブル生成ヘルパー
36
+ def make_html_table(caption):
37
+ rows = caption.split("\n")
38
+ html = '<table style="width:100%;border-collapse:collapse;background:#fffaf1;">'
39
+ for row in rows:
40
+ if ": " in row:
41
+ key, val = row.split(": ", 1)
42
+ html += (
43
+ f'<tr><th style="text-align:left;border:1px solid #ddd;padding:8px;color:#8b5e3c;">{key}</th>'
44
+ f'<td style="border:1px solid #ddd;padding:8px;">{val}</td></tr>'
45
+ )
46
+ html += '</table>'
47
+ return html
48
 
49
+ # 推論
50
+ def infer(prompt, neg, seed, rand, w, h, cfg, steps, scheduler_type, progress=gr.Progress(track_tqdm=True)):
51
+ if rand:
52
+ seed = random.randint(0, MAX_SEED)
53
+ gen = torch.Generator(device=device).manual_seed(seed)
54
+ pipe.scheduler = euler_scheduler if scheduler_type == "Euler Ancestral" else dpm_scheduler
55
+ pipe.scheduler.set_timesteps(steps)
56
+ def cb(p, i, t, kw):
57
+ progress(i/steps, desc=f"Step {i}/{steps}")
58
+ return kw
59
+ out = pipe(
60
+ prompt=prompt,
61
+ negative_prompt=neg or None,
62
+ guidance_scale=cfg,
63
+ num_inference_steps=steps,
64
+ width=w,
65
+ height=h,
66
+ generator=gen,
67
+ callback_on_step_end=cb
68
+ )
69
+ img = out.images[0]
70
+ cap = (
71
+ f"Prompt: {prompt}\n"
72
+ f"Negative: {neg or 'None'}\n"
73
+ f"Seed: {seed}\n"
74
+ f"Size: {w}×{h}\n"
75
+ f"CFG: {cfg}\n"
76
+ f"Steps: {steps}\n"
77
+ f"Scheduler: {scheduler_type}"
78
+ )
79
+ history.insert(0, (img, cap))
80
+ progress(1.0, desc="Done!")
81
+ return img, [(im, make_html_table(c)) for im, c in history]
82
 
83
+ # カフェ風スタイルのCSS
84
+ def get_cafe_css():
85
+ return """
86
+ body {
87
+ background-color: #f4e1c1;
88
+ font-family: 'Georgia', serif;
89
+ }
90
+ #col-container {
91
+ background: #fffaf1;
92
+ padding: 20px;
93
+ border-radius: 16px;
94
+ box-shadow: 0 4px 12px rgba(0,0,0,0.1);
95
+ margin: auto;
96
+ max-width: 780px;
97
+ }
98
+ #col-container a {
99
+ color: #8b5e3c;
100
+ text-decoration: none;
101
+ border-bottom: 1px dotted #8b5e3c;
102
+ }
103
+ """
104
 
105
+ with gr.Blocks(css=get_cafe_css()) as demo:
106
+ with gr.Column(elem_id="col-container"):
107
+ gr.Markdown("## SDXL Base – cocoamixXL3 Demo")
108
+ gr.Markdown("[Link: Civitai](https://civitai.com/models/1553716?modelVersionId=1855218)")
109
+ prompt = gr.Textbox(lines=1, placeholder="Prompt…", value="1girl, cocoart, masterpiece, anime,")
110
+ neg = gr.Textbox(lines=1, placeholder="Negative prompt", value="low quality, worst quality, bad shadow, lowres,")
111
+ seed_sl = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
112
+ rand = gr.Checkbox(True, label="Randomize seed")
113
+ width = gr.Slider(256, MAX_SIZE, step=32, value=512, label="Width")
114
+ height = gr.Slider(256, MAX_SIZE, step=32, value=512, label="Height")
115
+ cfg = gr.Slider(1.0, 30.0, step=0.1, value=7.5, label="CFG Scale")
116
+ steps = gr.Slider(1, 50, step=1, value=12, label="Steps")
117
+ scheduler_type = gr.Radio(["Euler Ancestral", "DPM++ 2M SDE"], value="Euler Ancestral", label="Scheduler")
118
+ run = gr.Button("Generate")
119
+ img_out = gr.Image()
120
+ gallery = gr.Gallery(label="生成履歴", columns=4, height=280, show_label=False, interactive=True, type="pil")
121
 
122
+ run.click(
123
+ fn=infer,
124
+ inputs=[prompt, neg, seed_sl, rand, width, height, cfg, steps, scheduler_type],
125
+ outputs=[img_out, gallery]
126
+ )
127
+
128
+ demo.queue()
129
+ demo.launch()