File size: 6,972 Bytes
153fa23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4268f1
153fa23
 
44a4152
 
 
 
40b8efc
44a4152
 
 
 
 
 
 
 
40b8efc
 
153fa23
 
 
 
 
 
 
 
 
 
 
40b8efc
153fa23
 
 
 
 
 
44a4152
153fa23
40b8efc
153fa23
 
 
40b8efc
153fa23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d85dd6
153fa23
 
 
ae99aa5
153fa23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a4152
153fa23
 
 
 
 
 
 
 
 
 
 
 
 
 
0d85dd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
from PIL import Image

# Compute settings shared by model loading and inference.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound for the seed slider and for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Largest width/height (in pixels) offered by the UI sliders.
MAX_IMAGE_SIZE = 2048

# Load the model once at import time so every request reuses the same pipeline.
# `dtype` and `device` are used here (instead of repeating the literals) so the
# settings above are the single source of truth for precision and placement.
pipe = GlmImagePipeline.from_pretrained(
    "zai-org/GLM-Image",
    torch_dtype=dtype,
).to(device)


def _load_rgb_images(input_images):
    """Convert the raw input items into a list of RGB PIL images.

    Each item may be a tuple whose first element is the image, a file path
    string, or a PIL image. Anything else is passed through unchanged.

    Returns:
        A list of images, or None when `input_images` is None or empty
        (text-to-image mode).
    """
    if not input_images:
        return None

    images = []
    for item in input_images:
        img = item[0] if isinstance(item, tuple) else item
        if isinstance(img, str):
            # convert() decodes the pixels into a new image object, so the
            # underlying file can be closed as soon as it returns.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        elif isinstance(img, Image.Image):
            img = img.convert("RGB")
        images.append(img)
    return images


@spaces.GPU(duration=120)
def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=50, guidance_scale=1.5, progress=gr.Progress(track_tqdm=True)):
    """Run the GLM-Image pipeline and return the generated image.

    Args:
        prompt: Text prompt (or editing instructions when images are given).
        input_images: Optional list of input images; when non-empty they are
            passed to the pipeline as `image` (image-to-image mode).
        seed: Seed for the random generator; ignored if `randomize_seed`.
        randomize_seed: When true, draw a fresh seed in [0, MAX_SEED].
        width: Requested output width; rounded down to a multiple of 32.
        height: Requested output height; rounded down to a multiple of 32.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A `(image, seed)` tuple: the first image produced by the pipeline and
        the seed that was actually used.
    """
    if randomize_seed:
        print("Randomizing seed")
        seed = random.randint(0, MAX_SEED)
    # Coerce numeric inputs to int in case the caller passes floats:
    # manual_seed() requires an integer and the sizes must be whole pixels.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # Ensure dimensions are multiples of 32
    width = (int(width) // 32) * 32
    height = (int(height) // 32) * 32

    # Use the module-level `device` rather than a hard-coded "cuda" string.
    generator = torch.Generator(device=device).manual_seed(seed)

    print("preparing images")
    # Prepare image list for image-to-image mode
    image_list = _load_rgb_images(input_images)

    print("handling kwargs")
    pipe_kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    print("adding images")
    # Only pass `image` when there is something to edit; omitting the key
    # keeps the pipeline in text-to-image mode.
    if image_list is not None:
        pipe_kwargs["image"] = image_list
    print("running kwargs")
    image = pipe(**pipe_kwargs).images[0]

    return image, seed


def update_dimensions_from_image(image_list):
    """Suggest width/height slider values from the first uploaded image.

    The longer side is set to 1024 and the shorter side follows the image's
    aspect ratio; both are then rounded to a multiple of 32 and clamped to
    the range [256, MAX_IMAGE_SIZE].

    Args:
        image_list: List of uploaded items. Each item may be a tuple whose
            first element is the image, a file path string, or an object
            exposing a `(width, height)` `size` attribute. May be None/empty.

    Returns:
        A `(width, height)` tuple. `(1024, 1024)` is returned when there is
        no image or the image reports a non-positive dimension.
    """
    default_dimensions = (1024, 1024)
    if not image_list:
        return default_dimensions

    # Get the first image to determine dimensions
    item = image_list[0]
    img = item[0] if isinstance(item, tuple) else item

    if isinstance(img, str):
        # Only the size is needed, so close the file as soon as it is read.
        with Image.open(img) as opened:
            img_width, img_height = opened.size
    else:
        img_width, img_height = img.size

    # Guard the aspect-ratio division against degenerate images.
    if img_width <= 0 or img_height <= 0:
        return default_dimensions

    aspect_ratio = img_width / img_height

    if aspect_ratio >= 1:  # Landscape or square
        new_width = 1024
        new_height = int(1024 / aspect_ratio)
    else:  # Portrait
        new_height = 1024
        new_width = int(1024 * aspect_ratio)

    # Round to nearest multiple of 32 (GLM-Image requirement)
    new_width = round(new_width / 32) * 32
    new_height = round(new_height / 32) * 32

    # Ensure within valid range
    new_width = max(256, min(MAX_IMAGE_SIZE, new_width))
    new_height = max(256, min(MAX_IMAGE_SIZE, new_height))

    return new_width, new_height

# Custom CSS passed to the demo at launch: centers the "col-container" column
# and caps its width, and makes images inside ".gallery-container" elements
# scale to fit their box without being cropped.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
.gallery-container img {
    object-fit: contain;
}
"""

# Build the Gradio UI: a two-column layout with the inputs (prompt, optional
# input images, advanced settings) on the left and the generated image on the
# right, followed by the event wiring.
with gr.Blocks() as demo:
    
    # elem_id matches the "#col-container" rule in the `css` string.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# GLM-Image
GLM-Image is a hybrid auto-regressive + diffusion 9B parameters model by z.ai
[[Model](https://huggingface.co/zai-org/GLM-Image)]
        """)
        
        with gr.Row():
            # Left column: all user inputs.
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=4,
                    placeholder="Enter your prompt (for text-to-image) or editing instructions (for image-to-image)",
                    container=False,
                    scale=3
                )
                
                run_button = gr.Button("🎨 Generate", variant="primary", scale=1)
                
                # Optional input images; `infer` switches to image-to-image
                # mode when this gallery is non-empty.
                with gr.Accordion("📷 Input Image(s) (optional - for image-to-image mode)", open=True):
                    input_images = gr.Gallery(
                        label="Input Image(s)",
                        type="pil",
                        columns=3,
                        rows=1,
                        elem_classes="gallery-container"
                    )
                    gr.Markdown("*Upload one or more images for image-to-image generation. Leave empty for text-to-image mode.*")
                
                # Generation parameters; defaults mirror the defaults of `infer`
                # except that seed randomization is enabled in the UI.
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )
                    
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                    # Output size sliders move in steps of 32 because `infer`
                    # rounds both dimensions down to a multiple of 32.
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            info="Must be a multiple of 32"
                        )
                        
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            info="Must be a multiple of 32"
                        )
                    
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                        )
                        
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=1.5,
                        )
                
            # Right column: the generated image.
            with gr.Column():
                result = gr.Image(label="Result", show_label=False)

    # Auto-update dimensions when images are uploaded
    input_images.upload(
        fn=update_dimensions_from_image,
        inputs=[input_images],
        outputs=[width, height]
    )

    # Run inference on button click or prompt submit. `infer` returns
    # (image, seed), so the seed slider is updated with the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed]
    )

# NOTE(review): theme and css are passed to launch() rather than gr.Blocks();
# confirm the installed Gradio version accepts them here.
demo.launch(theme=gr.themes.Citrus(), css=css)