Jonny001 committed on
Commit
6ed0c9b
·
verified ·
1 Parent(s): 0a49268

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +9 -7
  2. app.py +174 -0
  3. optimization.py +60 -0
  4. optimization_utils.py +96 -0
  5. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Image Editor
3
- emoji: 🏆
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.44.1
8
  app_file: app.py
9
- pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Images Editor
3
+ emoji: 🖼
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.34.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ short_description: 'Image editing using text'
12
  ---
13
 
14
+ Credits to: [Black Forest Labs](https://huggingface.co/black-forest-labs)
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# PyTorch 2.8 (temporary hack)
# NOTE(review): upgrading torch via os.system at import time is fragile and
# slows startup — confirm it is still needed once the Space base image ships
# a recent enough torch build.
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual demo code
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image

from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

from optimization import optimize_pipeline_

# Upper bound for the seed slider and random seed generation (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# Load the FLUX.1 Kontext editing pipeline onto the GPU in bfloat16, then run
# the ahead-of-time optimization pass once with a dummy 512x512 image so that
# subsequent calls use the compiled transformer.
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
@spaces.GPU
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.

    This function takes an input image and a text prompt to generate a modified version
    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
    for contextual image editing tasks.

    Args:
        input_image (PIL.Image.Image): The input image to be edited. Will be converted
            to RGB format if not already in that format. May be None, in which case
            generation is prompt-only.
        prompt (str): Text description of the desired edit to apply to the image.
            Examples: "Remove glasses", "Add a hat", "Change background to beach".
        seed (int, optional): Random seed for reproducible generation. Defaults to 42.
            Must be between 0 and MAX_SEED (2^31 - 1).
        randomize_seed (bool, optional): If True, generates a random seed instead of
            using the provided seed value. Defaults to False.
        guidance_scale (float, optional): Controls how closely the model follows the
            prompt. Higher values mean stronger adherence to the prompt but may reduce
            image quality. Range: 1.0-10.0. Defaults to 2.5.
        steps (int, optional): Controls how many steps to run the diffusion model for.
            Range: 1-30. Defaults to 28.
        progress (gr.Progress, optional): Gradio progress tracker for monitoring
            generation progress. Defaults to gr.Progress(track_tqdm=True).

    Returns:
        tuple: A 3-tuple containing:
            - PIL.Image.Image: The generated/edited image
            - int: The seed value used for generation (useful when randomize_seed=True)
            - gr.update: Gradio update object to make the reuse button visible

    Example:
        >>> edited_image, used_seed, button_update = infer(
        ...     input_image=my_image,
        ...     prompt="Add sunglasses",
        ...     seed=123,
        ...     randomize_seed=False,
        ...     guidance_scale=2.5
        ... )
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Arguments shared by the image-editing and prompt-only paths; building
    # them once avoids the previous duplicated pipe(...) call in each branch.
    pipe_kwargs = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        generator=torch.Generator().manual_seed(seed),
    )

    # Explicit None check: the Gradio image component passes None when empty,
    # and relying on PIL.Image truthiness here would be fragile.
    if input_image is not None:
        input_image = input_image.convert("RGB")
        # Generate at the input resolution so the edit preserves dimensions.
        pipe_kwargs.update(
            image=input_image,
            width=input_image.size[0],
            height=input_image.size[1],
        )

    image = pipe(**pipe_kwargs).images[0]
    # Reveal the "Reuse this image" button now that a result exists.
    return image, seed, gr.Button(visible=True)
@spaces.GPU
def infer_example(input_image, prompt):
    """Run `infer` with its default settings for the cached examples gallery.

    The reuse-button update returned by `infer` is dropped because the
    Examples component only wires up the image and seed outputs.
    """
    result_image, used_seed, _button = infer(input_image, prompt)
    return result_image, used_seed
# Page styling: center the main column and cap its width.
css="""
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev]
Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
""")
        with gr.Row():
            # Left column: inputs (image, prompt, advanced settings).
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Accordion("Advanced Settings", open=False):

                    # Seed slider; effectively ignored when "Randomize seed"
                    # is checked (infer draws a fresh seed in that case).
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )

                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )

            # Right column: output image and the reuse button (hidden until
            # the first result is produced).
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)


        # Clickable example prompts; results are cached lazily on first use.
        examples = gr.Examples(
            examples=[
                ["flowers.png", "turn the flowers into sunflowers"],
                ["monster.png", "make this monster ride a skateboard on the beach"],
                ["cat.png", "make this cat happy"]
            ],
            inputs=[input_image, prompt],
            outputs=[result, seed],
            fn=infer_example,
            cache_examples="lazy"
        )

    # Run inference on button click or when the prompt is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [input_image, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs = [result, seed, reuse_button]
    )
    # Feed the generated result back into the input slot for iterative edits.
    reuse_button.click(
        fn = lambda image: image,
        inputs = [result],
        outputs = [input_image]
    )

# mcp_server=True also exposes infer as an MCP tool (docstring is the schema).
demo.launch(mcp_server=True)
optimization.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Ahead-of-time (AOT) compilation of the FLUX.1 Kontext transformer.

Exports the pipeline's transformer with dynamic sequence shapes and compiles
it with TorchInductor so inference calls run the compiled artifact.
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile


P = ParamSpec('P')


# The packed-latent sequence length varies with input image resolution, so it
# is exported as a dynamic dimension.
# NOTE(review): max=8212 looks like a possible typo for 8192 — confirm the
# intended upper bound.
TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)

# Which transformer inputs carry the dynamic token dimension, and on which axis.
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
    'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
}

# TorchInductor tuning options used when compiling the transformer.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Replace ``pipeline.transformer`` in place with an AOT-compiled version.

    Runs the pipeline once with the provided example ``args``/``kwargs`` to
    capture a real transformer call, exports the transformer with dynamic
    sequence shapes, compiles it, and swaps it into the pipeline.
    """

    # Long GPU window: autotuned AOT compilation is slow.
    @spaces.GPU(duration=1500)
    def compile_transformer():

        # Run the pipeline once, intercepting the transformer call to record
        # the exact positional and keyword arguments it receives.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Start from "every tensor/bool kwarg is static" (None), then overlay
        # the dimensions we want to keep dynamic.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Fuse QKV projections before export so the compiled graph matches
        # the module that will be replaced.
        pipeline.transformer.fuse_qkv_projections()

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        return aoti_compile(exported, INDUCTOR_CONFIGS)

    # The compiled callable does not carry the original module's config;
    # restore it so downstream code reading `pipeline.transformer.config`
    # keeps working.
    transformer_config = pipeline.transformer.config
    pipeline.transformer = compile_transformer()
    pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
optimization_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Helpers for AOTInductor compilation and call capture on ZeroGPU Spaces."""
import contextlib
from contextvars import ContextVar
from io import BytesIO
from typing import Any
from typing import cast
from unittest.mock import patch

import torch
from torch._inductor.package.package import package_aoti
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import TensorProperties
from torch.export.pt2_archive._package_weights import Weights


# Keep model weights out of the compiled shared object so they can be managed
# (moved to CUDA, pinned, shared across processes) separately from the
# packaged archive. These always override caller-supplied configs.
INDUCTOR_CONFIGS_OVERRIDES = {
    'aot_inductor.package_constants_in_so': False,
    'aot_inductor.package_constants_on_disk': True,
    'aot_inductor.package': True,
}
class ZeroGPUCompiledModel:
    """An AOTInductor-compiled model whose weights are stored externally.

    Holds the packaged archive and the weights separately so the object can be
    pickled across ZeroGPU worker processes; the compiled archive is loaded
    lazily, once per context, on first call.
    """
    def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
        self.archive_file = archive_file
        self.weights = weights
        if cuda:
            self.weights_to_cuda_()
        # Per-context cache of the loaded compiled model (None until first call).
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
    def weights_to_cuda_(self):
        # Move every stored weight tensor to CUDA, preserving its properties.
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            self.weights[name] = (tensor.to('cuda'), properties)
    def __call__(self, *args, **kwargs):
        # Lazily load the archive and bind the externally-managed weights the
        # first time this context invokes the model.
        if (compiled_model := self.compiled_model.get()) is None:
            constants_map = {name: value[0] for name, value in self.weights.items()}
            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
            compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
            self.compiled_model.set(compiled_model)
        return compiled_model(*args, **kwargs)
    def __reduce__(self):
        # Pickle support: copy each weight into pinned, shared CPU memory so
        # the receiving process can map it cheaply. The rebuilt instance gets
        # cuda=True, which moves the weights back to the GPU on unpickling.
        weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
        for name in self.weights:
            tensor, properties = self.weights.get_weight(name)
            tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
            weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
        return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
50
+
51
+ def aoti_compile(
52
+ exported_program: torch.export.ExportedProgram,
53
+ inductor_configs: dict[str, Any] | None = None,
54
+ ):
55
+ inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
56
+ gm = cast(torch.fx.GraphModule, exported_program.module())
57
+ assert exported_program.example_inputs is not None
58
+ args, kwargs = exported_program.example_inputs
59
+ artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
60
+ archive_file = BytesIO()
61
+ files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
62
+ package_aoti(archive_file, files)
63
+ weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
64
+ return ZeroGPUCompiledModel(archive_file, weights)
65
+
66
+
@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method='forward',
):
    """Intercept one call to ``pipeline.<component_name>.<component_method>``.

    While the context is active, the component's method is replaced by a stub
    that aborts the call and smuggles the received arguments out through an
    exception. The yielded record object exposes them afterwards via its
    ``args`` and ``kwargs`` attributes; both stay empty if the method is
    never invoked inside the context.
    """

    class _Interrupt(Exception):
        # Carrier for the intercepted positional/keyword arguments.
        def __init__(self, *call_args, **call_kwargs):
            super().__init__()
            self.args = call_args
            self.kwargs = call_kwargs

    class _CallRecord:
        # Filled in once (and if) the patched method is hit.
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    record = _CallRecord()
    target = getattr(pipeline, component_name)

    def _stub(*call_args, **call_kwargs):
        raise _Interrupt(*call_args, **call_kwargs)

    with patch.object(target, component_method, new=_stub):
        try:
            yield record
        except _Interrupt as interrupt:
            record.args = interrupt.args
            record.kwargs = interrupt.kwargs
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ git+https://github.com/huggingface/diffusers.git
3
+ accelerate
4
+ safetensors
5
+ sentencepiece