File size: 10,759 Bytes
2706625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
FLUX.2 Gradio App
画像とテキストプロンプトを入力して、FLUX.2で画像生成を行うWebアプリ
"""

import gc

import gradio as gr
import torch
from diffusers import Flux2Pipeline
from googletrans import LANGUAGES, Translator
from PIL import Image


class FLUX2App:
    def __init__(self):
        self.pipe = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.repo_id = "diffusers/FLUX.2-dev-bnb-4bit"

    def load_model(self):
        """モデルを初回のみロード"""
        if self.pipe is None:
            print("モデルをロード中...")
            self.pipe = Flux2Pipeline.from_pretrained(
                self.repo_id,
                torch_dtype=self.torch_dtype
            ).to(self.device)

            # メモリ最適化
            self.pipe.enable_model_cpu_offload()
            print("モデルのロード完了")
        return self.pipe

    def generate_image(
        self,
        prompt: str,
        input_image: Image.Image,
        num_steps: int = 28,
        guidance_scale: float = 4.0,
        seed: int = 42,
        width: int = 1024,
        height: int = 768,
        progress=gr.Progress()
    ):
        """
        画像を生成

        Args:
            prompt: テキストプロンプト
            input_image: 入力画像(任意)
            num_steps: デノイジングステップ数
            guidance_scale: ガイダンススケール
            seed: 乱数シード
            width: 出力画像の幅
            height: 出力画像の高さ
            progress: Gradio Progress tracker

        Returns:
            生成された画像、ステータスメッセージ
        """
        try:
            # プロンプトチェック
            if not prompt or prompt.strip() == "":
                error_msg = "❌ エラー: プロンプトを入力してください"
                progress(0, desc=error_msg)
                return None, error_msg

            # モデルロード
            progress(0.1, desc="🔄 モデルをロード中...")
            pipe = self.load_model()

            # 生成パラメータ
            progress(0.2, desc="⚙️ パラメータを設定中...")
            generator = torch.Generator(device=self.device).manual_seed(seed)

            # 入力画像の処理
            images_input = [input_image] if input_image is not None else None

            # 画像生成
            progress(0.3, desc=f"🎨 画像生成中... (0/{num_steps} steps)")
            print(f"生成開始: prompt='{prompt[:50]}...', steps={num_steps}, guidance={guidance_scale}")

            # コールバック関数でプログレスを更新
            def callback(pipe, step_index, timestep, callback_kwargs):
                progress_value = 0.3 + (0.6 * (step_index + 1) / num_steps)
                progress(progress_value, desc=f"🎨 画像生成中... ({step_index + 1}/{num_steps} steps)")
                return callback_kwargs

            result = pipe(
                prompt=prompt,
                image=images_input,
                generator=generator,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                callback_on_step_end=callback,
            )

            output_image = result.images[0]

            # メモリクリア
            progress(0.95, desc="🧹 メモリをクリア中...")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            success_msg = f"✅ 生成完了! (steps={num_steps}, guidance={guidance_scale}, seed={seed})"
            progress(1.0, desc=success_msg)
            return output_image, success_msg

        except torch.cuda.OutOfMemoryError as e:
            error_msg = f"❌ VRAM不足エラー: メモリが足りません。ステップ数や解像度を下げてください。\n詳細: {str(e)}"
            print(error_msg)
            progress(0, desc="❌ VRAM不足エラー")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
            return None, error_msg
        except Exception as e:
            error_msg = f"❌ エラーが発生しました: {type(e).__name__}\n詳細: {str(e)}"
            print(error_msg)
            progress(0, desc=f"❌ {type(e).__name__}")
            return None, error_msg


async def translate_to_english(prompt):
    """プロンプトを英語に翻訳"""
    if not prompt or prompt.strip() == "":
        return "", "⚠️ プロンプトが空です"

    try:
        translator = Translator()
        # awaitを使用してコルーチンを実行
        translated = await translator.translate(prompt, src='ja', dest='en')
        lang_name = LANGUAGES.get('ja', 'Japanese')
        return translated.text, f"✅ 翻訳完了: {lang_name} → English\n原文: {prompt}\n翻訳: {translated.text}"

    except Exception as e:
        error_msg = f"❌ 翻訳エラー: {type(e).__name__}: {str(e)}"
        print(error_msg)
        return prompt, error_msg


def create_ui():
    """Gradio UIを作成"""
    app = FLUX2App()

    with gr.Blocks(title="FLUX.2 画像生成", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎨 FLUX.2 画像生成アプリ

            テキストプロンプトと入力画像(任意)から、FLUX.2で新しい画像を生成します。
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # 入力エリア
                gr.Markdown("### 入力")

                prompt_input = gr.Textbox(
                    label="プロンプト",
                    placeholder="生成したい画像の説明を入力してください(例: a beautiful sunset over the ocean)",
                    lines=3,
                    value="a beautiful sunset over the ocean with vibrant colors"
                )

                with gr.Row():
                    translate_btn = gr.Button("🌐 英語に翻訳", size="sm")

                translate_status = gr.Textbox(
                    label="翻訳ステータス",
                    interactive=False,
                    visible=False
                )

                image_input = gr.Image(
                    label="入力画像(任意)",
                    type="pil",
                    sources=["upload", "clipboard"]
                )

                with gr.Accordion("詳細設定", open=False):
                    num_steps = gr.Slider(
                        minimum=10,
                        maximum=50,
                        value=28,
                        step=1,
                        label="ステップ数(多いほど高品質だが時間がかかる)"
                    )

                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=4.0,
                        step=0.5,
                        label="ガイダンススケール(高いほどプロンプトに忠実)"
                    )

                    seed = gr.Number(
                        label="シード値(再現性確保)",
                        value=42,
                        precision=0
                    )

                    with gr.Row():
                        width = gr.Slider(
                            minimum=512,
                            maximum=2048,
                            value=1024,
                            step=64,
                            label="幅"
                        )
                        height = gr.Slider(
                            minimum=512,
                            maximum=2048,
                            value=768,
                            step=64,
                            label="高さ"
                        )

                generate_btn = gr.Button("🎨 生成", variant="primary", size="lg")

            with gr.Column(scale=1):
                # 出力エリア
                gr.Markdown("### 出力")

                output_image = gr.Image(
                    label="生成画像",
                    type="pil"
                )

                status_text = gr.Textbox(
                    label="ステータス",
                    interactive=False
                )

        # サンプル例
        gr.Markdown("### 📝 サンプルプロンプト例")
        gr.Examples(
            examples=[
                ["a photo of a forest with mist swirling around the tree trunks"],
                ["a clean monochrome CAD-style technical line drawing"],
                ["a beautiful landscape with mountains and a lake at sunset"],
                ["an astronaut riding a horse on the moon"],
                ["a cute cat wearing sunglasses, digital art"],
            ],
            inputs=[prompt_input],
            label="クリックしてプロンプトをセット"
        )

        # イベント設定
        # 翻訳ボタン
        async def on_translate(prompt):
            translated, status = await translate_to_english(prompt)
            return translated, status, gr.update(visible=True)

        translate_btn.click(
            fn=on_translate,
            inputs=[prompt_input],
            outputs=[prompt_input, translate_status, translate_status]
        )

        # 生成ボタン
        generate_btn.click(
            fn=app.generate_image,
            inputs=[
                prompt_input,
                image_input,
                num_steps,
                guidance_scale,
                seed,
                width,
                height,
            ],
            outputs=[output_image, status_text]
        )

        gr.Markdown(
            """
            ---
            **使い方:**
            1. プロンプトを入力(必須)
            2. 入力画像をアップロード(任意、編集モードの場合)
            3. 詳細設定を調整(任意)
            4. 「生成」ボタンをクリック

            **注意:**
            - 初回実行時はモデルのロードに時間がかかります
            - VRAM不足の場合はステップ数や解像度を下げてください
            """
        )

    return demo


if __name__ == "__main__":
    demo = create_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )