IdlecloudX commited on
Commit
cf41ddf
·
verified ·
1 Parent(s): 404bfc8

Upload 5 files

Browse files
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch 2.8 and dependencies (temporary hack)
2
+ import os
3
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces peft')
4
+
5
+ # Actual demo code
6
+ import gradio as gr
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+ import random
11
+ from PIL import Image
12
+
13
+ from diffusers import FluxKontextPipeline
14
+ from diffusers.utils import load_image
15
+
16
+ from optimization import optimize_pipeline_
17
+
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+
20
+ # 1. 加载基础模型
21
+ pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
22
+
23
+ # 2. 加载 LoRA
24
+ try:
25
+ pipe.load_lora_weights(".", weight_name="change_clothes_to_nothing_000012800.safetensors")
26
+ print("Successfully loaded LoRA weights from the root directory.")
27
+ except Exception as e:
28
+ print(f"Could not load LoRA weights. Please ensure 'change_clothes_to_nothing_000012800.safetensors' is in the root directory. Error: {e}")
29
+
30
+
31
+ # 3. 对加载了 LoRA 的模型进行优化
32
+ optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
33
+
34
+ @spaces.GPU
35
+ def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, lora_scale=1.0, progress=gr.Progress(track_tqdm=True)):
36
+ """
37
+ 使用 FLUX.1 Kontext pipeline 执行图像编辑。
38
+ """
39
+ if randomize_seed:
40
+ seed = random.randint(0, MAX_SEED)
41
+
42
+ if input_image:
43
+ input_image = input_image.convert("RGB")
44
+ image = pipe(
45
+ image=input_image,
46
+ prompt=prompt,
47
+ guidance_scale=guidance_scale,
48
+ width = input_image.size[0],
49
+ height = input_image.size[1],
50
+ num_inference_steps=steps,
51
+ generator=torch.Generator().manual_seed(seed),
52
+ cross_attention_kwargs={"scale": lora_scale}, # 应用 LoRA 强度
53
+ ).images[0]
54
+ else:
55
+ image = pipe(
56
+ prompt=prompt,
57
+ guidance_scale=guidance_scale,
58
+ num_inference_steps=steps,
59
+ generator=torch.Generator().manual_seed(seed),
60
+ cross_attention_kwargs={"scale": lora_scale}, # 应用 LoRA 强度
61
+ ).images[0]
62
+ return image, seed, gr.Button(visible=True)
63
+
64
+
65
+ css="""
66
+ #col-container {
67
+ margin: 0 auto;
68
+ max-width: 960px;
69
+ }
70
+ """
71
+
72
+ with gr.Blocks(css=css) as demo:
73
+
74
+ with gr.Column(elem_id="col-container"):
75
+ gr.Markdown(f"""# FLUX.1 Kontext [dev]
76
+ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
77
+ """)
78
+ with gr.Row():
79
+ with gr.Column():
80
+ input_image = gr.Image(label="上传要编辑的图片", type="pil")
81
+ with gr.Row():
82
+ prompt = gr.Text(
83
+ label="Prompt",
84
+ show_label=False,
85
+ max_lines=1,
86
+ placeholder="输入您的编辑指令 (例如: '移除眼镜', '添加一顶帽子')",
87
+ container=False,
88
+ )
89
+ run_button = gr.Button("运行", scale=0)
90
+ with gr.Accordion("高级设置", open=False):
91
+
92
+ lora_scale = gr.Slider(
93
+ label="LoRA 强度 (LoRA Scale)",
94
+ minimum=0.0,
95
+ maximum=2.0,
96
+ step=0.05,
97
+ value=0.8,
98
+ )
99
+
100
+ seed = gr.Slider(
101
+ label="随机种子 (Seed)",
102
+ minimum=0,
103
+ maximum=MAX_SEED,
104
+ step=1,
105
+ value=0,
106
+ )
107
+
108
+ randomize_seed = gr.Checkbox(label="随机化种子 (Randomize seed)", value=True)
109
+
110
+ guidance_scale = gr.Slider(
111
+ label="引导系数 (Guidance Scale)",
112
+ minimum=1,
113
+ maximum=10,
114
+ step=0.1,
115
+ value=2.5,
116
+ )
117
+
118
+ steps = gr.Slider(
119
+ label="步数 (Steps)",
120
+ minimum=1,
121
+ maximum=30,
122
+ value=28,
123
+ step=1
124
+ )
125
+
126
+ with gr.Column():
127
+ result = gr.Image(label="结果", show_label=False, interactive=False)
128
+ reuse_button = gr.Button("复用此图", visible=False)
129
+
130
+ gr.on(
131
+ triggers=[run_button.click, prompt.submit],
132
+ fn = infer,
133
+ inputs = [input_image, prompt, seed, randomize_seed, guidance_scale, steps, lora_scale],
134
+ outputs = [result, seed, reuse_button]
135
+ )
136
+ reuse_button.click(
137
+ fn = lambda image: image,
138
+ inputs = [result],
139
+ outputs = [input_image]
140
+ )
141
+
142
+ demo.launch(mcp_server=True)
change_clothes_to_nothing_000012800.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e35dc275e9945f7907c6501c14da45a235efb1df2cd087a5a27cc03c168d3c50
3
+ size 343806408
optimization.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+
8
+ import spaces
9
+ import torch
10
+ from torch.utils._pytree import tree_map_only
11
+
12
+ from optimization_utils import capture_component_call
13
+ from optimization_utils import aoti_compile
14
+
15
+
16
+ P = ParamSpec('P')
17
+
18
+
19
+ TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)
20
+
21
+ TRANSFORMER_DYNAMIC_SHAPES = {
22
+ 'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
23
+ 'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
24
+ }
25
+
26
+ INDUCTOR_CONFIGS = {
27
+ 'conv_1x1_as_mm': True,
28
+ 'epilogue_fusion': False,
29
+ 'coordinate_descent_tuning': True,
30
+ 'coordinate_descent_check_all_directions': True,
31
+ 'max_autotune': True,
32
+ 'triton.cudagraphs': True,
33
+ }
34
+
35
+
36
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
37
+
38
+ @spaces.GPU(duration=1500)
39
+ def compile_transformer():
40
+
41
+ with capture_component_call(pipeline, 'transformer') as call:
42
+ pipeline(*args, **kwargs)
43
+
44
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
45
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
46
+
47
+ pipeline.transformer.fuse_qkv_projections()
48
+
49
+ exported = torch.export.export(
50
+ mod=pipeline.transformer,
51
+ args=call.args,
52
+ kwargs=call.kwargs,
53
+ dynamic_shapes=dynamic_shapes,
54
+ )
55
+
56
+ return aoti_compile(exported, INDUCTOR_CONFIGS)
57
+
58
+ transformer_config = pipeline.transformer.config
59
+ pipeline.transformer = compile_transformer()
60
+ pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
optimization_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ import contextlib
4
+ from contextvars import ContextVar
5
+ from io import BytesIO
6
+ from typing import Any
7
+ from typing import cast
8
+ from unittest.mock import patch
9
+
10
+ import torch
11
+ from torch._inductor.package.package import package_aoti
12
+ from torch.export.pt2_archive._package import AOTICompiledModel
13
+ from torch.export.pt2_archive._package_weights import TensorProperties
14
+ from torch.export.pt2_archive._package_weights import Weights
15
+
16
+
17
+ INDUCTOR_CONFIGS_OVERRIDES = {
18
+ 'aot_inductor.package_constants_in_so': False,
19
+ 'aot_inductor.package_constants_on_disk': True,
20
+ 'aot_inductor.package': True,
21
+ }
22
+
23
+
24
+ class ZeroGPUCompiledModel:
25
+ def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
26
+ self.archive_file = archive_file
27
+ self.weights = weights
28
+ if cuda:
29
+ self.weights_to_cuda_()
30
+ self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
31
+ def weights_to_cuda_(self):
32
+ for name in self.weights:
33
+ tensor, properties = self.weights.get_weight(name)
34
+ self.weights[name] = (tensor.to('cuda'), properties)
35
+ def __call__(self, *args, **kwargs):
36
+ if (compiled_model := self.compiled_model.get()) is None:
37
+ constants_map = {name: value[0] for name, value in self.weights.items()}
38
+ compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
39
+ compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
40
+ self.compiled_model.set(compiled_model)
41
+ return compiled_model(*args, **kwargs)
42
+ def __reduce__(self):
43
+ weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
44
+ for name in self.weights:
45
+ tensor, properties = self.weights.get_weight(name)
46
+ tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
47
+ weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
48
+ return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
49
+
50
+
51
+ def aoti_compile(
52
+ exported_program: torch.export.ExportedProgram,
53
+ inductor_configs: dict[str, Any] | None = None,
54
+ ):
55
+ inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
56
+ gm = cast(torch.fx.GraphModule, exported_program.module())
57
+ assert exported_program.example_inputs is not None
58
+ args, kwargs = exported_program.example_inputs
59
+ artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
60
+ archive_file = BytesIO()
61
+ files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
62
+ package_aoti(archive_file, files)
63
+ weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
64
+ return ZeroGPUCompiledModel(archive_file, weights)
65
+
66
+
67
+ @contextlib.contextmanager
68
+ def capture_component_call(
69
+ pipeline: Any,
70
+ component_name: str,
71
+ component_method='forward',
72
+ ):
73
+
74
+ class CapturedCallException(Exception):
75
+ def __init__(self, *args, **kwargs):
76
+ super().__init__()
77
+ self.args = args
78
+ self.kwargs = kwargs
79
+
80
+ class CapturedCall:
81
+ def __init__(self):
82
+ self.args: tuple[Any, ...] = ()
83
+ self.kwargs: dict[str, Any] = {}
84
+
85
+ component = getattr(pipeline, component_name)
86
+ captured_call = CapturedCall()
87
+
88
+ def capture_call(*args, **kwargs):
89
+ raise CapturedCallException(*args, **kwargs)
90
+
91
+ with patch.object(component, component_method, new=capture_call):
92
+ try:
93
+ yield captured_call
94
+ except CapturedCallException as e:
95
+ captured_call.args = e.args
96
+ captured_call.kwargs = e.kwargs
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers
2
+ git+https://github.com/huggingface/diffusers.git
3
+ accelerate
4
+ safetensors
5
+ sentencepiece
6
+ peft