File size: 8,002 Bytes
ae8c337
 
 
1989f45
cfc3ce6
29598a1
ae8c337
5b8270b
 
 
29598a1
5b8270b
 
 
29598a1
f8e1897
5b8270b
2d9a7e8
ae8c337
 
5b8270b
f8e1897
5b8270b
fa1301c
2d9a7e8
 
 
 
 
 
 
 
 
 
 
5c654d4
 
 
 
 
 
2d9a7e8
f8e1897
 
 
 
 
 
 
 
 
 
 
 
ebb220e
f8e1897
 
 
25c8061
 
f8e1897
 
5b8270b
 
8ab990d
b07f8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b8270b
b07f8ef
 
 
 
 
 
 
 
 
5b8270b
 
a2a7641
ebb220e
f8e1897
2d9a7e8
 
 
 
b58298b
5c654d4
cfc3ce6
 
 
 
5c654d4
cfc3ce6
 
 
 
 
5c654d4
 
 
 
 
 
 
 
 
 
 
 
f00ffec
 
5c654d4
 
f00ffec
5b8270b
 
 
 
d4cae39
5b8270b
 
 
 
 
 
5c654d4
46ebb59
 
5b8270b
 
d4cae39
5c654d4
d4cae39
5c654d4
 
d4cae39
b07f8ef
d4cae39
 
 
 
 
 
 
 
 
 
b07f8ef
 
 
 
 
 
 
 
d4cae39
cb81188
d4cae39
5b8270b
f00ffec
 
5b8270b
5c654d4
5b8270b
8ab990d
d4cae39
 
 
 
 
 
5b8270b
 
b07f8ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# PyTorch 2.8 (temporary hack)
# HACK: at start-up, install/upgrade a pre-release torch (<2.9, from the CUDA 12.6
# nightly index), torchvision and spaces. This runs before torch/torchvision are
# imported below. The exit status returned by os.system is discarded, so a failed
# install is not reported here and can only surface later through the imports.
import os

os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" "torchvision" spaces')
# Set before diffusers is imported further below. Presumably this opts diffusers
# into loading kernels from the Hugging Face Hub — confirm against the installed
# diffusers version.
os.environ['DIFFUSERS_ENABLE_HUB_KERNELS']='yes'
import torchvision.transforms.functional as TF
# Actual demo code
import gradio as gr
import numpy as np
import spaces

import torch
import random
from PIL import Image

from pipeline import GenSIRR
from diffusers.utils import load_image
import torch.nn.functional as F
from optimization import optimize_pipeline_

MAX_SEED = np.iinfo(np.int32).max
from huggingface_hub import hf_hub_download

def pad_for_model(image: torch.Tensor, multiple: int):
    """Pad ``image`` on the bottom/right so its height and width are divisible by ``multiple``.

    Args:
        image: Tensor whose last two dimensions are (height, width), e.g. ``[C, H, W]``.
        multiple: Positive integer that the padded height and width must be divisible by.

    Returns:
        A tuple ``(padded, (pad_h, pad_w))``, where ``pad_h``/``pad_w`` are the number
        of rows/columns appended at the bottom/right. When no padding is needed, the
        input tensor itself is returned together with ``(0, 0)``.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")

    height, width = image.shape[-2:]
    # Python's modulo of a negative number gives the distance up to the next multiple.
    pad_h = -height % multiple
    pad_w = -width % multiple
    if pad_h == 0 and pad_w == 0:
        return image, (0, 0)

    # Reflect padding mirrors real content (fewer border artifacts than zeros), but
    # PyTorch requires each reflect pad to be smaller than the dimension it pads.
    # Very small inputs therefore fall back to edge replication instead of raising.
    mode = "reflect" if pad_h < height and pad_w < width else "replicate"
    # Add a leading batch dimension for F.pad and remove it again afterwards.
    padded = F.pad(image.unsqueeze(0), (0, pad_w, 0, pad_h), mode=mode).squeeze(0)
    return padded, (pad_h, pad_w)

def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a ``[C, H, W]`` tensor with values in ``[0, 1]`` to a PIL image.

    Values outside ``[0, 1]`` are clipped. The tensor is converted to float32
    before scaling, so half-precision inputs are quantized accurately. Values are
    rounded (not truncated) to the nearest 8-bit level. A single-channel input
    is returned as a 2-D (grayscale) image.

    Args:
        tensor: Image tensor of shape ``[C, H, W]``; it may live on any device.

    Returns:
        The corresponding ``PIL.Image.Image``.
    """
    scaled = tensor.detach().float().clamp(0.0, 1.0).mul(255.0).round()
    array = scaled.to(torch.uint8).permute(1, 2, 0).cpu().numpy()
    if array.shape[-1] == 1:
        # Image.fromarray cannot build an image from an (H, W, 1) uint8 array.
        array = array[..., 0]
    return Image.fromarray(array)

def load_deepspeed_weights(model, checkpoint_path) -> None:
    """Load the ``net_g.``-prefixed weights of a DeepSpeed checkpoint into ``model``.

    The checkpoint must be a dict whose ``"module"`` entry is the saved state dict.
    Only keys starting with ``"net_g."`` are kept, and that prefix is stripped so
    the keys match ``model``'s own parameter names. The load is strict, so the
    filtered state must cover ``model`` exactly.

    Args:
        model: Object exposing ``load_state_dict`` (here the GenSIRR pipeline).
        checkpoint_path: Local path of the checkpoint file.

    Raises:
        KeyError: If the checkpoint has no ``"module"`` state dict, or that dict
            has no ``"net_g."``-prefixed entries.
        RuntimeError: Propagated from ``load_state_dict`` on missing/unexpected
            keys or shape mismatches.
    """
    prefix = "net_g."
    # NOTE(review): torch.load deserializes a pickle-based format; only load
    # checkpoints from trusted sources.
    raw_state = torch.load(checkpoint_path, map_location="cpu")
    module_state = raw_state.get("module") if isinstance(raw_state, dict) else None
    if module_state is None:
        raise KeyError("Checkpoint is missing the 'module' state dict")

    # Remove the Lightning prefix so it matches the FluxKontext state dict.
    cleaned_state = {
        key[len(prefix):]: value
        for key, value in module_state.items()
        if key.startswith(prefix)
    }
    if not cleaned_state:
        raise KeyError(f"Checkpoint 'module' state dict has no '{prefix}'-prefixed keys")

    # strict=True already raises on missing/unexpected keys, so the returned
    # key lists carry no extra information and are not inspected.
    model.load_state_dict(cleaned_state, strict=True)

# Build the GenSIRR pipeline on top of the FLUX.1-Kontext-dev base model
# (presumably fetched from the Hugging Face Hub inside GenSIRR — confirm in pipeline.py).
pipe = GenSIRR("black-forest-labs/FLUX.1-Kontext-dev")
# Download GenSIRR.pt from the lime-j/GenSIRR Hub repo and strictly load its
# "net_g."-prefixed weights into the pipeline.
load_deepspeed_weights(pipe, hf_hub_download(repo_id='lime-j/GenSIRR', filename="GenSIRR.pt"))
# Optional speed-ups, currently disabled:
# pipe.transformer.fuse_qkv_projections()
# pipe.transformer.set_attention_backend("_flash_3_hub")
# Runs at import time: the model is moved to the GPU once, before the UI starts.
pipe = pipe.to("cuda")
# Ahead-of-time optimization hook, currently disabled:
# optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')

def _resize_short_side(image, size):
    """Resize ``image`` so its shorter side equals ``size``, keeping the aspect ratio."""
    width, height = image.size
    if width < height:
        return image.resize((size, int(size * height / width)))
    return image.resize((int(size * width / height), size))


@spaces.GPU
def infer(input_image, seed=42, randomize_seed=False, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Remove reflections from an image with the GenSIRR pipeline.

    The input is converted to RGB and resized so its shorter side is 512 px
    (aspect ratio preserved). It is then padded on the bottom/right to a multiple
    of 16 and run through the pipeline. The output is cropped back to the
    resized (unpadded) size.

    Args:
        input_image (PIL.Image.Image | None): Image to process. ``None`` (nothing
            uploaded) raises a ``gr.Error`` that is shown in the UI.
        seed (int, optional): Seed for the CPU ``torch.Generator``. Defaults to 42.
        randomize_seed (bool, optional): If True, draw a fresh seed in
            ``[0, MAX_SEED]`` instead of using ``seed``. Defaults to False.
        steps (int, optional): Number of inference steps passed to the pipeline.
            Defaults to 28; the UI slider allows 1-30.
        progress (gr.Progress, optional): Gradio progress tracker
            (``track_tqdm=True``).

    Returns:
        tuple: ``(image, seed, button)`` — the restored ``PIL.Image.Image``, the
        seed actually used, and ``gr.Button(visible=True)``, which reveals the
        "Reuse this image" button.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver numbers as floats; the generator seed and the
    # step count must be integers.
    seed = int(seed)
    steps = int(steps)

    input_image = _resize_short_side(input_image.convert("RGB"), 512)
    tensor = TF.to_tensor(input_image)
    with torch.inference_mode():
        height, width = tensor.shape[-2:]
        padded_tensor, _ = pad_for_model(tensor, 16)
        batch_device = padded_tensor.unsqueeze(0).to("cuda")
        output = pipe(
            image=batch_device,
            # Declare the padded size so it matches the image tensor actually
            # passed in; the result is cropped back to (height, width) below.
            width=padded_tensor.shape[-1],
            height=padded_tensor.shape[-2],
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
        )
        output_tensor = output[0] if isinstance(output, tuple) else output
        output_tensor = output_tensor.squeeze(0).detach().cpu()[..., :height, :width]
        output_image = tensor_to_image(output_tensor)
    return output_image, seed, gr.Button(visible=True)

@spaces.GPU
def infer_example(input_image):
    """Run ``infer`` with default settings, returning only ``(image, seed)``.

    The button update that ``infer`` also returns is dropped, because example
    rows only need the picture and the seed.
    """
    outcome = infer(input_image)
    restored_image, used_seed, _ = outcome
    return (restored_image, used_seed)

# Narrow, centered layout for the main column.
css="""
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        # Plain string literal: there are no placeholders, so no f-prefix is needed.
        gr.Markdown("""# GenSIRR: Rectifying Latent Space for Generative SIRR
        This is a demo for our generative single-image reflection removal model. To limit the running time, the model here runs at a 512px resolution.
        We strongly suggest using 768px or 1024px for better performance.
        """)
        with gr.Row():
            # Left column: input image, run button and advanced settings.
            with gr.Column():
                input_image = gr.Image(label="Upload the image for reflection removal", type="pil")
                with gr.Row():
                    run_button = gr.Button("Run")

                with gr.Accordion("Advanced Settings", open=False):

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )

            # Right column: result plus a button (hidden until a run finishes)
            # that copies the result back into the input.
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)

    # infer also writes back the seed actually used and reveals the reuse button.
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[input_image, seed, randomize_seed, steps],
        outputs=[result, seed, reuse_button]
    )
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )

demo.launch(mcp_server=True)