File size: 6,701 Bytes
3fa7365
 
 
927acc2
2f4ae32
3fa7365
626bd75
d8173ca
927acc2
 
 
626bd75
927acc2
 
 
 
 
 
 
 
 
 
 
d8173ca
927acc2
d8173ca
927acc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fa7365
927acc2
3fa7365
 
 
 
 
 
2f4ae32
3fa7365
8d1b29c
927acc2
 
3fa7365
 
 
30e32ff
927acc2
30e32ff
 
 
3fa7365
30e32ff
 
 
 
 
 
 
 
 
 
 
 
3fa7365
 
 
927acc2
 
3fa7365
 
 
 
 
 
927acc2
3fa7365
 
 
927acc2
30e32ff
9395213
3fa7365
 
927acc2
3fa7365
927acc2
 
 
 
 
3fa7365
9395213
 
 
927acc2
 
 
 
3fa7365
 
927acc2
3fa7365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
927acc2
3fa7365
 
 
 
 
 
 
927acc2
3fa7365
 
 
 
 
 
 
 
 
927acc2
3fa7365
 
 
 
 
 
 
927acc2
3fa7365
 
2f4ae32
 
8725e45
927acc2
 
 
9395213
 
 
 
8725e45
 
 
 
 
 
927acc2
 
 
 
 
 
 
 
 
 
 
e542ef8
927acc2
 
3fa7365
 
0436c77
30e32ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import gradio as gr
import numpy as np
import random
import spaces  # [uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
import os
from huggingface_hub import login
from openai import OpenAI

# Initialize API keys and login
# Credentials come from environment variables (e.g. Hugging Face Space secrets).
hf_token = os.getenv("space_token")  # Hugging Face access token; None if the env var is unset
openai_api_key = os.getenv("openai_apikey")  # OpenAI key, consumed later by make_prompt()
# Authenticate with the Hugging Face Hub at import time.
# NOTE(review): presumably required to download the gated model below — confirm.
login(token=hf_token)

def make_prompt(place):
    """Turn a place keyword into a full image-generation prompt via GPT-4.

    Sends a few-shot chat request asking for three comma-separated English
    keywords (location, action, background) for the given place, then
    prefixes them with the fixed character description.

    Args:
        place: Short place description (e.g. "nature", "宇宙") appended to
            the final user message.

    Returns:
        The prompt string: "Purotan, short brown hair, bright smile, " +
        the model's keyword list.
    """
    # System instruction (Japanese): list exactly three English words —
    # place, action, background — used to render an image of a man.
    system_instruction = """userの入力するplace_infoを基に、英単語を、3つ羅列してください。
                        その英単語を基に、男の人の画像生成を行います。英単語の順番は以下の通りです。
                        場所, 行っている動作, 背景の様子

                        ex: The ocean, swimming, with sharks

                        ##output format:
                          place_hoge, moving_hoge, background_hoge"""

    # One-shot example pair followed by the real query.
    conversation = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": "place_info: nature"},
        {"role": "assistant", "content": "forest, exploration, there are tigers"},
        {"role": "user", "content": "place_info: " + place},
    ]

    api_client = OpenAI(api_key=openai_api_key)
    completion = api_client.chat.completions.create(
        model="gpt-4",
        messages=conversation,
        temperature=1,
    )

    # Expected shape: e.g. "mountains, hiking, clear sky"
    keywords = completion.choices[0].message.content.strip()
    return "Purotan, short brown hair, bright smile, " + keywords

# Set device and load model
# Runs at import time: downloads the base model, moves it to the GPU when
# available, and applies the character LoRA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-dev"  # Replace with your model

# fp16 on GPU to halve memory; fp32 on CPU (fp16 CPU inference is poorly supported).
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
# LoRA weights for the "Purotan" character referenced in make_prompt().
# NOTE(review): path is relative — assumes the file sits next to this script.
pipe.load_lora_weights("purotan_1750.safetensors")

# Constants
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / randomizer
MAX_IMAGE_SIZE = 1024  # upper bound for the width/height sliders

# Define the image generation function
@spaces.GPU  # [uncomment to use ZeroGPU]
def generate_image(place, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    """Run the diffusion pipeline for the selected place.

    Builds the text prompt via make_prompt(), optionally randomizes the
    seed, and renders one image with the module-level `pipe`.

    Args:
        place: Place keyword forwarded to make_prompt().
        seed: RNG seed; ignored and replaced when randomize_seed is True.
        randomize_seed: If True, draw a fresh seed in [0, MAX_SEED].
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.

    Returns:
        (image, seed): the generated PIL image and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    text_prompt = make_prompt(place)

    # CPU generator seeded explicitly so runs are reproducible for a given seed.
    rng = torch.Generator().manual_seed(seed)

    result = pipe(
        prompt=text_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    # Seed is returned so the UI can display the value that was really used.
    return result.images[0], seed

# CSS for styling
# Centers the main column and caps its width; applied via gr.Blocks(css=css).
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Define the Gradio interface
# NOTE: statement order inside this block defines both the on-screen layout
# and the component references used in the event wiring below — keep intact.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        # Title + instructions ("press a button to pick a place and generate").
        gr.Markdown("""
        # Text-to-Image Gradio Template
        ボタンを押して場所を選択し、画像を生成してください!
        """)
        
        # Place Selection Buttons
        # Nature / Cityscape / Fantasy world / Daily life / Space.
        with gr.Row():
            btn_nature = gr.Button("自然")
            btn_cityscape = gr.Button("都市景観")
            btn_fantasy = gr.Button("ファンタジー世界")
            btn_daily = gr.Button("日常生活")
            btn_space = gr.Button("宇宙")
        
        # Display Selected Place
        # Mirrors the `selected_place` state; initial value matches its default.
        selected_place_display = gr.Markdown("**選択された場所:** 自然")
        
        # Run Button
        run_button = gr.Button("Run", scale=0)
        
        # Image Output
        result = gr.Image(label="Result", show_label=False)

        # Advanced Settings Accordion
        with gr.Accordion("Advanced Settings", open=False):
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=720,  # Adjust based on your model's capabilities
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1280,  # Adjust based on your model's capabilities
                )
            
            with gr.Row():
                
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,  # Adjust based on your model's capabilities
                )
                
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,  # Adjust based on your model's capabilities
                )
        
        # Removed the gr.Examples section to fix the ValueError
        # If you wish to add examples, ensure they align with the input components

    # State to keep track of selected place
    selected_place = gr.State("自然")  # Default to "自然"

    # Define functions to set the selected place and update the display
    def set_place(place):
        """Return the new state value and the markdown text showing it."""
        return place, f"**選択された場所:** {place}"
    
    # Connect buttons to state setter functions using lambda
    # Each button writes its fixed place into the state and updates the label.
    btn_nature.click(fn=lambda: set_place("自然"), outputs=[selected_place, selected_place_display])
    btn_cityscape.click(fn=lambda: set_place("都市景観"), outputs=[selected_place, selected_place_display])
    btn_fantasy.click(fn=lambda: set_place("ファンタジー世界"), outputs=[selected_place, selected_place_display])
    btn_daily.click(fn=lambda: set_place("日常生活"), outputs=[selected_place, selected_place_display])
    btn_space.click(fn=lambda: set_place("宇宙"), outputs=[selected_place, selected_place_display])

    # Connect Run button to the image generation function
    # Input order must match generate_image's parameter order.
    run_button.click(
        fn=generate_image,
        inputs=[
            selected_place, 
            seed, 
            randomize_seed, 
            width, 
            height, 
            guidance_scale, 
            num_inference_steps
        ],
        outputs=[result, seed]
    )

# Queue requests so concurrent users are serialized, then start the app.
demo.queue().launch()