Flulike99 committed
Commit 81bf056 · 1 Parent(s): 51935a3
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README copy.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Ads Ap
+ emoji: 📉
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app.py
+ pinned: false
+ short_description: Advertisement generation
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,420 @@
+ import os
+ import sys
+ from typing import Union, Any, Optional
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ import spaces
+
+ # Add the project root directory to the Python path
+ project_root = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(project_root)
+ hf_token = os.environ.get("CASCADE_PRIVATE_MODEL_HF_TOKEN")
+ secret_model = os.environ.get("MODEL_PATH")
+
+ try:
+     from diffusers import FluxTransformer2DModel
+     from diffusers.pipelines import FluxPipeline
+     from flux.condition import Condition
+     from flux.generate import generate
+     from flux.lora_controller import set_lora_scale
+     FLUX_AVAILABLE = True
+ except ImportError as e:
+     print(f"Warning: FLUX components not available: {e}")
+     FLUX_AVAILABLE = False
+
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ # Download the model file using the auth token
+ model_path = hf_hub_download(
+     repo_id="spaces/Cascade-Inc/private_model",
+     filename=secret_model,
+     token=hf_token,
+     repo_type="space"
+ )
+
+ # Get temp directory
+ temp_dir = os.path.join(os.path.expanduser("~"), "gradio_temp")
+ os.makedirs(temp_dir, exist_ok=True)
+ os.environ["GRADIO_TEMP_DIR"] = temp_dir
+
+ # Global state
+ pipe: Union["FluxPipeline", None] = None
+ use_int8 = False
+
+ ADAPTER_NAME = "subject"
+ MODEL_PATH = model_path
+
+ def get_gpu_memory_gb() -> float:
+     return torch.cuda.get_device_properties(0).total_memory / 1024**3
+
+ def init_pipeline_if_needed():
+     global pipe
+     if pipe is not None:
+         return
+
+     if use_int8 or get_gpu_memory_gb() < 33:
+         transformer_model = FluxTransformer2DModel.from_pretrained(
+             "sayakpaul/flux.1-schell-int8wo-improved",
+             torch_dtype=torch.bfloat16,
+             use_safetensors=False,
+         )
+         _pipe = FluxPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-schnell",
+             transformer=transformer_model,
+             torch_dtype=torch.bfloat16,
+         )
+     else:
+         _pipe = FluxPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+         )
+
+     _pipe = _pipe.to("cuda")
+     _pipe.load_lora_weights(MODEL_PATH, adapter_name=ADAPTER_NAME)
+     _pipe.set_adapters([ADAPTER_NAME])
+     pipe = _pipe
+
+ def _to_pil_rgba(img: Any) -> Image.Image:
+     """Convert various inputs to PIL RGBA image"""
+     pil: Optional[Image.Image] = None
+
+     if isinstance(img, Image.Image):
+         pil = img
+     elif isinstance(img, np.ndarray):
+         pil = Image.fromarray(img)
+     elif isinstance(img, str) and os.path.exists(img):
+         pil = Image.open(img)
+     else:
+         raise ValueError("Unsupported image type")
+
+     if pil.mode != "RGBA":
+         pil = pil.convert("RGBA")
+     return pil
+
+ def _place_subject_on_canvas(
+     subject_rgba: Image.Image,
+     canvas_size: int,
+     style: str,
+     base_coverage: float = 0.7,
+ ) -> Image.Image:
+     """
+     Place subject on transparent canvas with position and angle adjustments based on style
+     """
+     canvas = Image.new("RGBA", (canvas_size, canvas_size), (0, 0, 0, 0))
+
+     # Define three styles
+     styles = {
+         "center": {"scale": 1.0, "rotation": 0, "pos": (0.0, 0.0)},
+         "tilt_left": {"scale": 0.95, "rotation": -15, "pos": (-0.1, 0.0)},
+         "right": {"scale": 0.95, "rotation": 0, "pos": (0.25, 0.0)},
+     }
+
+     if style not in styles:
+         style = "center"
+
+     style_config = styles[style]
+
+     # Calculate scaling
+     subject_w, subject_h = subject_rgba.size
+     max_dim = max(subject_w, subject_h)
+     desired_max_dim = max(1, int(canvas_size * base_coverage * style_config["scale"]))
+     scale = desired_max_dim / max(1, max_dim)
+     new_w = max(1, int(subject_w * scale))
+     new_h = max(1, int(subject_h * scale))
+     resized = subject_rgba.resize((new_w, new_h), Image.LANCZOS)
+
+     # Rotation
+     rotated = resized.rotate(style_config["rotation"], expand=True, resample=Image.BICUBIC)
+     rw, rh = rotated.size
+
+     # Positioning
+     cx = canvas_size // 2
+     cy = canvas_size // 2
+     dx = int(style_config["pos"][0] * canvas_size)
+     dy = int(style_config["pos"][1] * canvas_size)
+
+     paste_x = int(cx + dx - rw // 2)
+     paste_y = int(cy + dy - rh // 2)
+
+     canvas.alpha_composite(rotated, dest=(paste_x, paste_y))
+     return canvas
+
+ def _place_subject_on_canvas_rect(
+     subject_rgba: Image.Image,
+     canvas_width: int,
+     canvas_height: int,
+     style: str,
+     base_coverage: float = 0.7,
+ ) -> Image.Image:
+     """
+     Place subject on rectangular transparent canvas with position and angle adjustments based on style
+     """
+     canvas = Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 0))
+
+     # Define three styles
+     styles = {
+         "center": {"scale": 1.0, "rotation": 0, "pos": (0.0, 0.0)},
+         "tilt_left": {"scale": 0.95, "rotation": -15, "pos": (-0.1, 0.0)},
+         "right": {"scale": 0.95, "rotation": 0, "pos": (0.25, 0.0)},
+     }
+
+     if style not in styles:
+         style = "center"
+
+     style_config = styles[style]
+
+     # Calculate scaling based on smaller dimension
+     subject_w, subject_h = subject_rgba.size
+     max_dim = max(subject_w, subject_h)
+     canvas_min_dim = min(canvas_width, canvas_height)
+     desired_max_dim = max(1, int(canvas_min_dim * base_coverage * style_config["scale"]))
+     scale = desired_max_dim / max(1, max_dim)
+     new_w = max(1, int(subject_w * scale))
+     new_h = max(1, int(subject_h * scale))
+     resized = subject_rgba.resize((new_w, new_h), Image.LANCZOS)
+
+     # Rotation
+     rotated = resized.rotate(style_config["rotation"], expand=True, resample=Image.BICUBIC)
+     rw, rh = rotated.size
+
+     # Positioning
+     cx = canvas_width // 2
+     cy = canvas_height // 2
+     dx = int(style_config["pos"][0] * canvas_width)
+     dy = int(style_config["pos"][1] * canvas_height)
+
+     paste_x = int(cx + dx - rw // 2)
+     paste_y = int(cy + dy - rh // 2)
+
+     canvas.alpha_composite(rotated, dest=(paste_x, paste_y))
+     return canvas
+
+ def apply_style(image: Image.Image, style: str, width: int = 1024, height: int = 1024) -> Image.Image:
+     """Apply specified style to image with custom dimensions"""
+     if image is None:
+         # Create default transparent image
+         image = Image.new("RGBA", (512, 512), (255, 255, 255, 0))
+
+     # Ensure image is in RGBA format
+     if image.mode != "RGBA":
+         image = image.convert("RGBA")
+
+     # Apply style with custom dimensions
+     styled_image = _place_subject_on_canvas_rect(image, width, height, style)
+     return styled_image
+
+ def generate_background_local(styled_image: Image.Image, prompt: str, steps: int = 10, width: int = 1024, height: int = 1024) -> Image.Image:
+     """Generate background using local FLUX model"""
+     if not FLUX_AVAILABLE:
+         # Return a simple gradient background if FLUX is not available
+         if styled_image is None:
+             return Image.new("RGB", (width, height), (200, 200, 255))
+         # Create a simple colored background
+         bg = Image.new("RGB", (width, height), (200, 220, 255))
+         if styled_image.mode == "RGBA":
+             bg.paste(styled_image, (0, 0), styled_image)
+         else:
+             bg.paste(styled_image, (0, 0))
+         return bg
+
+     init_pipeline_if_needed()
+
+     if styled_image is None:
+         return Image.new("RGB", (width, height), (255, 255, 255))
+
+     # Convert to RGB for background generation
+     img_rgb = styled_image.convert("RGB")
+
+     condition = Condition(ADAPTER_NAME, img_rgb, position_delta=(0, 0))
+
+     # Enable padding token orthogonalization for enhanced text-image alignment
+     model_config = {
+         'padding_orthogonalization_enabled': True,
+         'preserve_norm': True,
+         'orthogonalize_all_tokens': False,
+     }
+
+     with set_lora_scale([ADAPTER_NAME], scale=3.0):
+         result_img = generate(
+             pipe,
+             model_config=model_config,
+             prompt=prompt.strip() if prompt else "",
+             conditions=[condition],
+             num_inference_steps=steps,
+             height=height,
+             width=width,
+             default_lora=True,
+         ).images[0]
+
+     return result_img
+
+ # Gradio Interface
+ @spaces.GPU
+ def create_simple_app():
+     # Example prompts for reference
+     example_prompts = [
+         {
+             "title": "Handcrafted Leather Wallet",
+             "prompt": "A hand-stitched, dark brown leather wallet lies half-open on a wooden desk with a map, next to a brass pen and compass. A stack of classic books is in the background. A warm desk lamp from the right highlights the leather texture. Classic, rustic style."
+         },
+         {
+             "title": "Sparkling Water with Fresh Lemons",
+             "prompt": "A dewy glass bottle of sparkling water on a white marble countertop, next to a sliced lemon and ice cubes. The background is a soft-focus, pale blue gradient. Lighting is bright, even, and cool-toned from above. Clean, crisp, minimalist style."
+         },
+         {
+             "title": "High-tech Smartwatch",
+             "prompt": "A titanium smartwatch with an illuminated screen rests on a black matte slate rock. The background is a blurred cityscape at night with neon bokeh. A sharp, direct light from the top-left highlights the watch's metallic edge. Futuristic, tech-focused style."
+         },
+         {
+             "title": "Japanese Ramen Bowl",
+             "prompt": "A ceramic bowl of tonkotsu ramen with chashu pork and a soft-boiled egg on a wooden table, with chopsticks beside it. Rising steam is caught in soft overhead light. The background is a blurred, cozy izakaya. Warm, authentic, and appetizing style."
+         },
+         {
+             "title": "Japanese Peach Iced Tea",
+             "prompt": "A bottle of Japanese peach iced tea beside a tall glass with tea and sparkling ice cubes. The background is a soft, warm peach and beige gradient. Lit with bright, soft light to appear crisp and refreshing. The style is clean, minimalist, and refined."
+         }
+     ]
+
+     with gr.Blocks(title="Ads Background Generation") as app:
+         gr.Markdown("# Ads Background Generation App")
+         gr.Markdown("Upload an image with transparent background → Enter prompt → Generate")
+
+         # Example Prompts Section
+         with gr.Accordion("📝 Example Prompts (Click to expand)", open=False):
+             gr.Markdown("### Background Prompt Examples")
+             gr.Markdown("Click any example below to copy it to the background description field:")
+
+             # Create example buttons
+             example_buttons = []
+             with gr.Row():
+                 for i, example in enumerate(example_prompts):
+                     if i < 3:  # First row
+                         example_btn = gr.Button(
+                             f"📋 {example['title']}",
+                             variant="secondary",
+                             size="sm"
+                         )
+                         example_buttons.append(example_btn)
+
+             with gr.Row():
+                 for i, example in enumerate(example_prompts):
+                     if i >= 3:  # Second row
+                         example_btn = gr.Button(
+                             f"📋 {example['title']}",
+                             variant="secondary",
+                             size="sm"
+                         )
+                         example_buttons.append(example_btn)
+
+             # Display area for selected prompt preview
+             selected_prompt_display = gr.Textbox(
+                 label="Selected Prompt Preview",
+                 lines=4,
+                 max_lines=8,
+                 interactive=False,
+                 visible=False
+             )
+
+         with gr.Row():
+             # Left column
+             with gr.Column(scale=1):
+                 # Image upload (top left)
+                 input_image = gr.Image(
+                     label="Upload Image (Transparent Background)",
+                     type="pil",
+                     format="png",
+                     image_mode="RGBA",
+                     height=350
+                 )
+
+                 # Image dimensions
+                 with gr.Row():
+                     img_width = gr.Number(
+                         value=1024,
+                         label="Width",
+                         precision=0,
+                         minimum=256,
+                         maximum=2048
+                     )
+                     img_height = gr.Number(
+                         value=1024,
+                         label="Height",
+                         precision=0,
+                         minimum=256,
+                         maximum=2048
+                     )
+
+                 # Background prompt (bottom left)
+                 bg_prompt = gr.Textbox(
+                     label="Background Description",
+                     placeholder="e.g.: Forest scene, soft lighting",
+                     lines=3
+                 )
+
+                 # Generation steps
+                 steps_slider = gr.Slider(
+                     minimum=5,
+                     maximum=20,
+                     value=10,
+                     step=1,
+                     label="Generation Steps"
+                 )
+
+                 # Generate background button
+                 generate_bg_btn = gr.Button("Generate Background", variant="primary", size="lg")
+
+             # Right column - Result display
+             with gr.Column(scale=1):
+                 final_result = gr.Image(
+                     label="Generated Result",
+                     type="pil",
+                     format="png",
+                     height=700
+                 )
+
+         # Generate background directly from input image
+         def generate_from_input(image, prompt, steps, width, height):
+             if image is None:
+                 return None
+
+             # Ensure image is RGBA
+             if image.mode != "RGBA":
+                 image = image.convert("RGBA")
+
+             # Generate background using local model only
+             return generate_background_local(image, prompt, steps, width, height)
+
+         # Event binding
+         generate_bg_btn.click(
+             fn=generate_from_input,
+             inputs=[input_image, bg_prompt, steps_slider, img_width, img_height],
+             outputs=[final_result]
+         )
+
+         # Example prompt button handlers
+         def create_example_handler(prompt_text):
+             def handler():
+                 return prompt_text, gr.update(value=prompt_text, visible=True)
+             return handler
+
+         # Connect example buttons to background prompt field and preview
+         for i, example_btn in enumerate(example_buttons):
+             if i < len(example_prompts):
+                 example_btn.click(
+                     fn=create_example_handler(example_prompts[i]['prompt']),
+                     outputs=[bg_prompt, selected_prompt_display]
+                 )
+
+     return app
+
+ if __name__ == "__main__":
+     app = create_simple_app()
+     app.launch(
+         debug=True,
+         share=False,
+         server_name="0.0.0.0",
+         server_port=7860
+     )
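Editor's note: a minimal local smoke test of the generation path above, shown as a sketch only. It assumes the CASCADE_PRIVATE_MODEL_HF_TOKEN and MODEL_PATH secrets are available in the environment, a CUDA GPU is present, and "product.png" is a hypothetical transparent-background product image.

    from PIL import Image
    import app  # the Space's app.py above; importing it downloads the private LoRA weights

    subject = Image.open("product.png")  # hypothetical RGBA product shot
    styled = app.apply_style(subject, "center", width=1024, height=1024)
    result = app.generate_background_local(styled, "studio background, soft lighting", steps=10)
    result.save("ad.png")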
flux/__init__.py ADDED
File without changes
flux/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (134 Bytes)
flux/__pycache__/block.cpython-312.pyc ADDED
Binary file (13.3 kB)
flux/__pycache__/condition.cpython-312.pyc ADDED
Binary file (5.74 kB)
flux/__pycache__/generate.cpython-312.pyc ADDED
Binary file (12.8 kB)
flux/__pycache__/lora_controller.cpython-312.pyc ADDED
Binary file (4.22 kB)
flux/__pycache__/padding_orthogonalization.cpython-312.pyc ADDED
Binary file (9.47 kB)
flux/__pycache__/pipeline_tools.cpython-312.pyc ADDED
Binary file (3.43 kB)
flux/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (7.27 kB)
flux/block.py ADDED
@@ -0,0 +1,339 @@
+ import torch
+ from typing import List, Union, Optional, Dict, Any, Callable
+ from diffusers.models.attention_processor import Attention, F
+ from .lora_controller import enable_lora
+
+
+ def attn_forward(
+     attn: Attention,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor = None,
+     condition_latents: torch.FloatTensor = None,
+     attention_mask: Optional[torch.FloatTensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+     cond_rotary_emb: Optional[torch.Tensor] = None,
+     model_config: Optional[Dict[str, Any]] = {},
+ ) -> torch.FloatTensor:
+     batch_size, _, _ = (
+         hidden_states.shape
+         if encoder_hidden_states is None
+         else encoder_hidden_states.shape
+     )
+
+     with enable_lora(
+         (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
+     ):
+         # `sample` projections.
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(hidden_states)
+         value = attn.to_v(hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+     value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+     if encoder_hidden_states is not None:
+         # `context` projections.
+         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+         encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+         encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+             batch_size, -1, attn.heads, head_dim
+         ).transpose(1, 2)
+
+         if attn.norm_added_q is not None:
+             encoder_hidden_states_query_proj = attn.norm_added_q(
+                 encoder_hidden_states_query_proj
+             )
+         if attn.norm_added_k is not None:
+             encoder_hidden_states_key_proj = attn.norm_added_k(
+                 encoder_hidden_states_key_proj
+             )
+
+         # attention
+         query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+         key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+         value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         query = apply_rotary_emb(query, image_rotary_emb)
+         key = apply_rotary_emb(key, image_rotary_emb)
+
+     if condition_latents is not None:
+         cond_query = attn.to_q(condition_latents)
+         cond_key = attn.to_k(condition_latents)
+         cond_value = attn.to_v(condition_latents)
+
+         cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
+             1, 2
+         )
+         cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
+             1, 2
+         )
+         if attn.norm_q is not None:
+             cond_query = attn.norm_q(cond_query)
+         if attn.norm_k is not None:
+             cond_key = attn.norm_k(cond_key)
+
+     if cond_rotary_emb is not None:
+         cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
+         cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
+
+     if condition_latents is not None:
+         query = torch.cat([query, cond_query], dim=2)
+         key = torch.cat([key, cond_key], dim=2)
+         value = torch.cat([value, cond_value], dim=2)
+
+     if not model_config.get("union_cond_attn", True):
+         # If we don't want to use the union condition attention, we need to mask the attention
+         # between the hidden states and the condition latents
+         attention_mask = torch.ones(
+             query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+         )
+         condition_n = cond_query.shape[2]
+         attention_mask[-condition_n:, :-condition_n] = False
+         attention_mask[:-condition_n, -condition_n:] = False
+     elif model_config.get("independent_condition", False):
+         attention_mask = torch.ones(
+             query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
+         )
+         condition_n = cond_query.shape[2]
+         attention_mask[-condition_n:, :-condition_n] = False
+     if hasattr(attn, "c_factor"):
+         attention_mask = torch.zeros(
+             query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
+         )
+         condition_n = cond_query.shape[2]
+         bias = torch.log(attn.c_factor[0])
+         attention_mask[-condition_n:, :-condition_n] = bias
+         attention_mask[:-condition_n, -condition_n:] = bias
+     hidden_states = F.scaled_dot_product_attention(
+         query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
+     )
+     hidden_states = hidden_states.transpose(1, 2).reshape(
+         batch_size, -1, attn.heads * head_dim
+     )
+     hidden_states = hidden_states.to(query.dtype)
+
+     if encoder_hidden_states is not None:
+         if condition_latents is not None:
+             encoder_hidden_states, hidden_states, condition_latents = (
+                 hidden_states[:, : encoder_hidden_states.shape[1]],
+                 hidden_states[
+                     :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
+                 ],
+                 hidden_states[:, -condition_latents.shape[1] :],
+             )
+         else:
+             encoder_hidden_states, hidden_states = (
+                 hidden_states[:, : encoder_hidden_states.shape[1]],
+                 hidden_states[:, encoder_hidden_states.shape[1] :],
+             )
+
+         with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
+             # linear proj
+             hidden_states = attn.to_out[0](hidden_states)
+             # dropout
+             hidden_states = attn.to_out[1](hidden_states)
+         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+         if condition_latents is not None:
+             condition_latents = attn.to_out[0](condition_latents)
+             condition_latents = attn.to_out[1](condition_latents)
+
+         return (
+             (hidden_states, encoder_hidden_states, condition_latents)
+             if condition_latents is not None
+             else (hidden_states, encoder_hidden_states)
+         )
+     elif condition_latents is not None:
+         # if there are condition_latents, we need to separate the hidden_states and the condition_latents
+         hidden_states, condition_latents = (
+             hidden_states[:, : -condition_latents.shape[1]],
+             hidden_states[:, -condition_latents.shape[1] :],
+         )
+         return hidden_states, condition_latents
+     else:
+         return hidden_states
+
+
+ def block_forward(
+     self,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor,
+     condition_latents: torch.FloatTensor,
+     temb: torch.FloatTensor,
+     cond_temb: torch.FloatTensor,
+     cond_rotary_emb=None,
+     image_rotary_emb=None,
+     model_config: Optional[Dict[str, Any]] = {},
+ ):
+     use_cond = condition_latents is not None
+     with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+             hidden_states, emb=temb
+         )
+
+     norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
+         self.norm1_context(encoder_hidden_states, emb=temb)
+     )
+
+     if use_cond:
+         (
+             norm_condition_latents,
+             cond_gate_msa,
+             cond_shift_mlp,
+             cond_scale_mlp,
+             cond_gate_mlp,
+         ) = self.norm1(condition_latents, emb=cond_temb)
+
+     # Attention.
+     result = attn_forward(
+         self.attn,
+         model_config=model_config,
+         hidden_states=norm_hidden_states,
+         encoder_hidden_states=norm_encoder_hidden_states,
+         condition_latents=norm_condition_latents if use_cond else None,
+         image_rotary_emb=image_rotary_emb,
+         cond_rotary_emb=cond_rotary_emb if use_cond else None,
+     )
+     attn_output, context_attn_output = result[:2]
+     cond_attn_output = result[2] if use_cond else None
+
+     # Process attention outputs for the `hidden_states`.
+     # 1. hidden_states
+     attn_output = gate_msa.unsqueeze(1) * attn_output
+     hidden_states = hidden_states + attn_output
+     # 2. encoder_hidden_states
+     context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+     encoder_hidden_states = encoder_hidden_states + context_attn_output
+     # 3. condition_latents
+     if use_cond:
+         cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
+         condition_latents = condition_latents + cond_attn_output
+         if model_config.get("add_cond_attn", False):
+             hidden_states += cond_attn_output
+
+     # LayerNorm + MLP.
+     # 1. hidden_states
+     norm_hidden_states = self.norm2(hidden_states)
+     norm_hidden_states = (
+         norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+     )
+     # 2. encoder_hidden_states
+     norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+     norm_encoder_hidden_states = (
+         norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+     )
+     # 3. condition_latents
+     if use_cond:
+         norm_condition_latents = self.norm2(condition_latents)
+         norm_condition_latents = (
+             norm_condition_latents * (1 + cond_scale_mlp[:, None])
+             + cond_shift_mlp[:, None]
+         )
+
+     # Feed-forward.
+     with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
+         # 1. hidden_states
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+     # 2. encoder_hidden_states
+     context_ff_output = self.ff_context(norm_encoder_hidden_states)
+     context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
+     # 3. condition_latents
+     if use_cond:
+         cond_ff_output = self.ff(norm_condition_latents)
+         cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
+
+     # Process feed-forward outputs.
+     hidden_states = hidden_states + ff_output
+     encoder_hidden_states = encoder_hidden_states + context_ff_output
+     if use_cond:
+         condition_latents = condition_latents + cond_ff_output
+
+     # Clip to avoid overflow.
+     if encoder_hidden_states.dtype == torch.float16:
+         encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+     return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
+
+
+ def single_block_forward(
+     self,
+     hidden_states: torch.FloatTensor,
+     temb: torch.FloatTensor,
+     image_rotary_emb=None,
+     condition_latents: torch.FloatTensor = None,
+     cond_temb: torch.FloatTensor = None,
+     cond_rotary_emb=None,
+     model_config: Optional[Dict[str, Any]] = {},
+ ):
+
+     using_cond = condition_latents is not None
+     residual = hidden_states
+     with enable_lora(
+         (
+             self.norm.linear,
+             self.proj_mlp,
+         ),
+         model_config.get("latent_lora", False),
+     ):
+         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+     if using_cond:
+         residual_cond = condition_latents
+         norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
+         mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
+
+     attn_output = attn_forward(
+         self.attn,
+         model_config=model_config,
+         hidden_states=norm_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         **(
+             {
+                 "condition_latents": norm_condition_latents,
+                 "cond_rotary_emb": cond_rotary_emb if using_cond else None,
+             }
+             if using_cond
+             else {}
+         ),
+     )
+     if using_cond:
+         attn_output, cond_attn_output = attn_output
+
+     with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
+         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+         gate = gate.unsqueeze(1)
+         hidden_states = gate * self.proj_out(hidden_states)
+         hidden_states = residual + hidden_states
+         if using_cond:
+             condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
+             cond_gate = cond_gate.unsqueeze(1)
+             condition_latents = cond_gate * self.proj_out(condition_latents)
+             condition_latents = residual_cond + condition_latents
+
+     if hidden_states.dtype == torch.float16:
+         hidden_states = hidden_states.clip(-65504, 65504)
+
+     return hidden_states if not using_cond else (hidden_states, condition_latents)
flux/condition.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ from typing import Optional, Union, List, Tuple
+ from diffusers.pipelines import FluxPipeline
+ from PIL import Image, ImageFilter
+ import numpy as np
+ import cv2
+
+ from .pipeline_tools import encode_images
+
+ condition_dict = {
+     "depth": 0,
+     "canny": 1,
+     "subject": 4,
+     "coloring": 6,
+     "deblurring": 7,
+     "depth_pred": 8,
+     "fill": 9,
+     "sr": 10,
+     "cartoon": 11,
+ }
+
+
+ class Condition(object):
+     def __init__(
+         self,
+         condition_type: str,
+         raw_img: Union[Image.Image, torch.Tensor] = None,
+         condition: Union[Image.Image, torch.Tensor] = None,
+         mask=None,
+         position_delta=None,
+         position_scale=1.0,
+     ) -> None:
+         self.condition_type = condition_type
+         assert raw_img is not None or condition is not None
+         if raw_img is not None:
+             self.condition = self.get_condition(condition_type, raw_img)
+         else:
+             self.condition = condition
+         self.position_delta = position_delta
+         self.position_scale = position_scale
+         # TODO: Add mask support
+         assert mask is None, "Mask not supported yet"
+
+     def get_condition(
+         self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
+     ) -> Union[Image.Image, torch.Tensor]:
+         """
+         Returns the condition image.
+         """
+         if condition_type == "depth":
+             from transformers import pipeline
+
+             depth_pipe = pipeline(
+                 task="depth-estimation",
+                 model="LiheYoung/depth-anything-small-hf",
+                 device="cuda",
+             )
+             source_image = raw_img.convert("RGB")
+             condition_img = depth_pipe(source_image)["depth"].convert("RGB")
+             return condition_img
+         elif condition_type == "canny":
+             img = np.array(raw_img)
+             edges = cv2.Canny(img, 100, 200)
+             edges = Image.fromarray(edges).convert("RGB")
+             return edges
+         elif condition_type == "subject":
+             return raw_img
+         elif condition_type == "coloring":
+             return raw_img.convert("L").convert("RGB")
+         elif condition_type == "deblurring":
+             condition_image = (
+                 raw_img.convert("RGB")
+                 .filter(ImageFilter.GaussianBlur(10))
+                 .convert("RGB")
+             )
+             return condition_image
+         elif condition_type == "fill":
+             return raw_img.convert("RGB")
+         elif condition_type == "cartoon":
+             return raw_img.convert("RGB")
+         return self.condition
+
+     @property
+     def type_id(self) -> int:
+         """
+         Returns the type id of the condition.
+         """
+         return condition_dict[self.condition_type]
+
+     @classmethod
+     def get_type_id(cls, condition_type: str) -> int:
+         """
+         Returns the type id of the condition.
+         """
+         return condition_dict[condition_type]
+
+     def encode(
+         self, pipe: FluxPipeline, empty: bool = False
+     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+         """
+         Encodes the condition into tokens, ids and type_id.
+         """
+         if self.condition_type in [
+             "depth",
+             "canny",
+             "subject",
+             "coloring",
+             "deblurring",
+             "depth_pred",
+             "fill",
+             "sr",
+             "cartoon",
+         ]:
+             if empty:
+                 # make the condition black
+                 e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
+                 e_condition = e_condition.convert("RGB")
+                 tokens, ids = encode_images(pipe, e_condition)
+             else:
+                 tokens, ids = encode_images(pipe, self.condition)
+         else:
+             raise NotImplementedError(
+                 f"Condition type {self.condition_type} not implemented"
+             )
+         if self.position_delta is None and self.condition_type == "subject":
+             self.position_delta = [0, -self.condition.size[0] // 16]
+         if self.position_delta is not None:
+             ids[:, 1] += self.position_delta[0]
+             ids[:, 2] += self.position_delta[1]
+         if self.position_scale != 1.0:
+             scale_bias = (self.position_scale - 1.0) / 2
+             ids[:, 1] *= self.position_scale
+             ids[:, 2] *= self.position_scale
+             ids[:, 1] += scale_bias
+             ids[:, 2] += scale_bias
+         type_id = torch.ones_like(ids[:, :1]) * self.type_id
+         return tokens, ids, type_id
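Editor's note: a hedged sketch of how the class above is used on its own, matching the call made in app.py. It assumes `pipe` is an already-loaded FluxPipeline and `img` is an RGB PIL image.

    from flux.condition import Condition

    cond = Condition("subject", raw_img=img, position_delta=(0, 0))
    tokens, ids, type_id = cond.encode(pipe)  # VAE-encoded tokens, position ids, and the condition type id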
flux/generate.py ADDED
@@ -0,0 +1,366 @@
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ import yaml
+ from diffusers.pipelines import FluxPipeline
+ from diffusers.pipelines.flux.pipeline_flux import (
+     FluxPipelineOutput,
+     calculate_shift,
+     np,
+     retrieve_timesteps,
+ )
+
+ from .condition import Condition
+ from .transformer import tranformer_forward
+ from .padding_orthogonalization import apply_padding_token_orthogonalization
+
+
+ def get_config(config_path: str = None):
+     config_path = config_path or os.environ.get("XFL_CONFIG")
+     if not config_path:
+         return {}
+     with open(config_path, "r") as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ def prepare_params(
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Optional[Union[str, List[str]]] = None,
+     height: Optional[int] = 512,
+     width: Optional[int] = 512,
+     num_inference_steps: int = 28,
+     timesteps: List[int] = None,
+     guidance_scale: float = 3.5,
+     num_images_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     prompt_embeds: Optional[torch.FloatTensor] = None,
+     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+     callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+     max_sequence_length: int = 512,
+     **kwargs: dict,
+ ):
+     return (
+         prompt,
+         prompt_2,
+         height,
+         width,
+         num_inference_steps,
+         timesteps,
+         guidance_scale,
+         num_images_per_prompt,
+         generator,
+         latents,
+         prompt_embeds,
+         pooled_prompt_embeds,
+         output_type,
+         return_dict,
+         joint_attention_kwargs,
+         callback_on_step_end,
+         callback_on_step_end_tensor_inputs,
+         max_sequence_length,
+     )
+
+
+ def seed_everything(seed: int = 42):
+     torch.backends.cudnn.deterministic = True
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+
+ @torch.no_grad()
+ def generate(
+     pipeline: FluxPipeline,
+     conditions: List[Condition] = None,
+     config_path: str = None,
+     model_config: Optional[Dict[str, Any]] = {},
+     condition_scale: float = 1.0,
+     default_lora: bool = False,
+     default_lora_path: str = None,
+     image_guidance_scale: float = 1.0,
+     **params: dict,
+ ):
+     """
+     Enhanced Flux text-to-image generation with padding token orthogonalization.
+
+     This function implements the padding token orthogonalization method from the poster
+     "Enhanced Text-to-Image Generation via Padding Token Orthogonalization" to improve
+     text-image alignment quality.
+
+     Args:
+         pipeline: FluxPipeline instance
+         conditions: List of condition objects
+         config_path: Path to configuration file
+         model_config: Model configuration dictionary. Supports:
+             - padding_orthogonalization_enabled (bool): Enable/disable orthogonalization (default: True)
+             - preserve_norm (bool): Preserve original embedding norms (default: True)
+             - orthogonalize_all_tokens (bool): Orthogonalize all tokens vs only padding (default: False)
+         condition_scale: Scale factor for conditions
+         default_lora: Whether to use default LoRA
+         default_lora_path: Path to default LoRA weights
+         image_guidance_scale: Scale for image guidance
+         **params: Additional generation parameters
+
+     Returns:
+         Generated images with enhanced text-image alignment
+     """
+     model_config = model_config or get_config(config_path).get("model", {})
+     if condition_scale != 1:
+         for name, module in pipeline.transformer.named_modules():
+             if not name.endswith(".attn"):
+                 continue
+             module.c_factor = torch.ones(1, 1) * condition_scale
+     if default_lora and default_lora_path:
+         pipeline.load_lora_weights(default_lora_path)
+
+     self = pipeline
+     (
+         prompt,
+         prompt_2,
+         height,
+         width,
+         num_inference_steps,
+         timesteps,
+         guidance_scale,
+         num_images_per_prompt,
+         generator,
+         latents,
+         prompt_embeds,
+         pooled_prompt_embeds,
+         output_type,
+         return_dict,
+         joint_attention_kwargs,
+         callback_on_step_end,
+         callback_on_step_end_tensor_inputs,
+         max_sequence_length,
+     ) = prepare_params(**params)
+
+     height = height or self.default_sample_size * self.vae_scale_factor
+     width = width or self.default_sample_size * self.vae_scale_factor
+
+     # 1. Check inputs. Raise error if not correct
+     self.check_inputs(
+         prompt,
+         prompt_2,
+         height,
+         width,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+         max_sequence_length=max_sequence_length,
+     )
+
+     self._guidance_scale = guidance_scale
+     self._joint_attention_kwargs = joint_attention_kwargs
+     self._interrupt = False
+
+     # 2. Define call parameters
+     if prompt is not None and isinstance(prompt, str):
+         batch_size = 1
+     elif prompt is not None and isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         batch_size = prompt_embeds.shape[0]
+
+     device = self._execution_device
+
+     lora_scale = (
+         self.joint_attention_kwargs.get("scale", None)
+         if self.joint_attention_kwargs is not None
+         else None
+     )
+     (
+         prompt_embeds,
+         pooled_prompt_embeds,
+         text_ids,
+     ) = self.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         device=device,
+         num_images_per_prompt=num_images_per_prompt,
+         max_sequence_length=max_sequence_length,
+         lora_scale=lora_scale,
+     )
+
+     # Apply Padding Token Orthogonalization for enhanced text-image alignment
+     if model_config.get('padding_orthogonalization_enabled', True):
+         prompt_embeds = apply_padding_token_orthogonalization(
+             prompt_embeds=prompt_embeds,
+             text_attention_mask=None,  # Will use heuristic if not available
+             config=model_config,
+         )
+
+     # 4. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels // 4
+     latents, latent_image_ids = self.prepare_latents(
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+
+     # 4.1. Prepare conditions
+     condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
+     use_condition = conditions is not None or []
+     if use_condition:
+         assert len(conditions) <= 1, "Only one condition is supported for now."
+         if not default_lora:
+             pipeline.set_adapters(conditions[0].condition_type)
+         for condition in conditions:
+             print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+             print(f"Condition: {condition.condition_type}")
+             tokens, ids, type_id = condition.encode(self)
+             condition_latents.append(tokens)  # [batch_size, token_n, token_dim]
+             condition_ids.append(ids)  # [token_n, id_dim(3)]
+             condition_type_ids.append(type_id)  # [token_n, 1]
+         condition_latents = torch.cat(condition_latents, dim=1)
+         condition_ids = torch.cat(condition_ids, dim=0)
+         condition_type_ids = torch.cat(condition_type_ids, dim=0)
+
+     # 5. Prepare timesteps
+     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+     image_seq_len = latents.shape[1]
+     mu = calculate_shift(
+         image_seq_len,
+         self.scheduler.config.base_image_seq_len,
+         self.scheduler.config.max_image_seq_len,
+         self.scheduler.config.base_shift,
+         self.scheduler.config.max_shift,
+     )
+     timesteps, num_inference_steps = retrieve_timesteps(
+         self.scheduler,
+         num_inference_steps,
+         device,
+         timesteps,
+         sigmas,
+         mu=mu,
+     )
+     num_warmup_steps = max(
+         len(timesteps) - num_inference_steps * self.scheduler.order, 0
+     )
+     self._num_timesteps = len(timesteps)
+
+     # 6. Denoising loop
+     with self.progress_bar(total=num_inference_steps) as progress_bar:
+         for i, t in enumerate(timesteps):
+             if self.interrupt:
+                 continue
+
+             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+             # handle guidance
+             if self.transformer.config.guidance_embeds:
+                 guidance = torch.tensor([guidance_scale], device=device)
+                 guidance = guidance.expand(latents.shape[0])
+             else:
+                 guidance = None
+             noise_pred = tranformer_forward(
+                 self.transformer,
+                 model_config=model_config,
+                 # Inputs of the condition (new feature)
+                 condition_latents=condition_latents if use_condition else None,
+                 condition_ids=condition_ids if use_condition else None,
+                 condition_type_ids=condition_type_ids if use_condition else None,
+                 # Inputs to the original transformer
+                 hidden_states=latents,
+                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                 timestep=timestep / 1000,
+                 guidance=guidance,
+                 pooled_projections=pooled_prompt_embeds,
+                 encoder_hidden_states=prompt_embeds,
+                 txt_ids=text_ids,
+                 img_ids=latent_image_ids,
+                 joint_attention_kwargs=self.joint_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             if image_guidance_scale != 1.0:
+                 uncondition_latents = condition.encode(self, empty=True)[0]
+                 # Fix: when `guidance` is None, create a suitable replacement tensor
+                 # (a tensor of ones with shape [latents.shape[0]])
+                 guidance_replacement = torch.ones(latents.shape[0], device=device)
+                 unc_pred = tranformer_forward(
+                     self.transformer,
+                     model_config=model_config,
+                     # Inputs of the condition (new feature)
+                     condition_latents=uncondition_latents if use_condition else None,
+                     condition_ids=condition_ids if use_condition else None,
+                     condition_type_ids=condition_type_ids if use_condition else None,
+                     # Inputs to the original transformer
+                     hidden_states=latents,
+                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                     timestep=timestep / 1000,
+                     # guidance=torch.ones_like(guidance),
+                     guidance=guidance_replacement,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     # joint_attention_kwargs=self.joint_attention_kwargs,
+                     joint_attention_kwargs=None,
+                     return_dict=False,
+                 )[0]
+
+                 noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents_dtype = latents.dtype
+             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if latents.dtype != latents_dtype:
+                 if torch.backends.mps.is_available():
+                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                     latents = latents.to(latents_dtype)
+
+             if callback_on_step_end is not None:
+                 callback_kwargs = {}
+                 for k in callback_on_step_end_tensor_inputs:
+                     callback_kwargs[k] = locals()[k]
+                 callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                 latents = callback_outputs.pop("latents", latents)
+                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+             # call the callback, if provided
+             if i == len(timesteps) - 1 or (
+                 (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+             ):
+                 progress_bar.update()
+
+     if output_type == "latent":
+         image = latents
+
+     else:
+         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents = (
+             latents / self.vae.config.scaling_factor
+         ) + self.vae.config.shift_factor
+         image = self.vae.decode(latents, return_dict=False)[0]
+         image = self.image_processor.postprocess(image, output_type=output_type)
+
+     # Offload all models
+     self.maybe_free_model_hooks()
+
+     if condition_scale != 1:
+         for name, module in pipeline.transformer.named_modules():
+             if not name.endswith(".attn"):
+                 continue
+             del module.c_factor
+
+     if not return_dict:
+         return (image,)
+
+     return FluxPipelineOutput(images=image)
flux/lora_controller.py ADDED
@@ -0,0 +1,77 @@
+ from peft.tuners.tuners_utils import BaseTunerLayer
+ from typing import List, Any, Optional, Type
+ from .condition import condition_dict
+
+ class enable_lora:
+     def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
+         self.activated: bool = activated
+         if activated:
+             return
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+
+     def __enter__(self) -> None:
+         if self.activated:
+             return
+
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 if active_adapter in condition_dict.keys():
+                     lora_module.scaling[active_adapter] = 0.0
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         if self.activated:
+             return
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
+
+
+ class set_lora_scale:
+     def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
+         self.lora_modules: List[BaseTunerLayer] = [
+             each for each in lora_modules if isinstance(each, BaseTunerLayer)
+         ]
+         self.scales = [
+             {
+                 active_adapter: lora_module.scaling[active_adapter]
+                 for active_adapter in lora_module.active_adapters
+             }
+             for lora_module in self.lora_modules
+         ]
+         self.scale = scale
+
+     def __enter__(self) -> None:
+         for lora_module in self.lora_modules:
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             lora_module.scale_layer(self.scale)
+
+     def __exit__(
+         self,
+         exc_type: Optional[Type[BaseException]],
+         exc_val: Optional[BaseException],
+         exc_tb: Optional[Any],
+     ) -> None:
+         for i, lora_module in enumerate(self.lora_modules):
+             if not isinstance(lora_module, BaseTunerLayer):
+                 continue
+             for active_adapter in lora_module.active_adapters:
+                 lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
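Editor's note: a sketch of how these context managers are used by block.py in this commit, which passes PEFT-wrapped projection layers rather than adapter names. The module path below is illustrative and assumes `pipe` is a FluxPipeline with a LoRA adapter loaded.

    from flux.lora_controller import enable_lora, set_lora_scale

    attn = pipe.transformer.transformer_blocks[0].attn  # illustrative layer choice
    with enable_lora((attn.to_q, attn.to_k, attn.to_v), False):
        ...  # condition-type adapters on these layers are zeroed inside this block, then restored
    with set_lora_scale([attn.to_q, attn.to_k, attn.to_v], scale=3.0):
        ...  # the LoRA contribution on these layers is scaled up inside this block, then restored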
flux/padding_orthogonalization.py ADDED
@@ -0,0 +1,252 @@
1
+ """
2
+ Enhanced Text-to-Image Generation via Padding Token Orthogonalization
3
+
4
+ This module implements the padding token orthogonalization method described in the poster
5
+ "Enhanced Text-to-Image Generation via Padding Token Orthogonalization" by Jiafeng Mao,
6
+ Qianru Qiu, Xueting Wang from CyberAgent AI Lab.
7
+
8
+ The core idea is to use padding tokens as registers that collect, store, and redistribute
9
+ features across layers via attention pathways through Gram-Schmidt orthogonalization.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing import Optional, Tuple
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
21
+ """
22
+ Orthogonalize rows of matrix X using QR decomposition.
23
+
24
+ This is the core function from the poster: Q, _ = torch.linalg.qr(X.T) return Q.T
25
+
26
+ Args:
27
+ X: Input tensor of shape (..., n_rows, n_cols)
28
+
29
+ Returns:
30
+ Orthogonalized tensor of the same shape
31
+ """
32
+ # Save original dtype and convert to float32 for QR decomposition
33
+ original_dtype = X.dtype
34
+ original_shape = X.shape
35
+
36
+ # Convert to float32 if needed (QR doesn't support bfloat16)
37
+ if X.dtype == torch.bfloat16:
38
+ X = X.to(torch.float32)
39
+
40
+ # Handle batch dimensions by flattening
41
+ if X.dim() > 2:
42
+ # Reshape to (batch_size, n_rows, n_cols)
43
+ X_flat = X.view(-1, original_shape[-2], original_shape[-1])
44
+ results = []
45
+
46
+ for i in range(X_flat.shape[0]):
47
+ # Apply QR decomposition: Q, _ = torch.linalg.qr(X.T)
48
+ Q, _ = torch.linalg.qr(X_flat[i].T)
49
+ # Return Q.T to get orthogonalized rows
50
+ results.append(Q.T)
51
+
52
+ result = torch.stack(results, dim=0)
53
+ # Reshape back to original shape
54
+ result = result.view(original_shape)
55
+ else:
56
+ # Simple 2D case
57
+ Q, _ = torch.linalg.qr(X.T)
58
+ result = Q.T
59
+
60
+ # Convert back to original dtype
61
+ if original_dtype == torch.bfloat16:
62
+ result = result.to(original_dtype)
63
+
64
+ return result
65
+
66
+
67
+ class PaddingTokenOrthogonalizer(nn.Module):
68
+ """
69
+ A module that applies padding token orthogonalization to text embeddings.
70
+
71
+ Based on the poster's method, this enhances text-image alignment by:
72
+ 1. Identifying padding tokens in the sequence
73
+ 2. Orthogonalizing their representations using QR decomposition
74
+ 3. Maintaining feature diversity and preventing biased attention
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ enabled: bool = True,
80
+ preserve_norm: bool = True,
81
+ orthogonalize_all: bool = False,
82
+ ):
83
+ """
84
+ Args:
85
+ enabled: Whether to apply orthogonalization
86
+ preserve_norm: Whether to preserve the original norm of tokens
87
+ orthogonalize_all: If True, orthogonalize all tokens; if False, only padding tokens
88
+ """
89
+ super().__init__()
90
+ self.enabled = enabled
91
+ self.preserve_norm = preserve_norm
92
+ self.orthogonalize_all = orthogonalize_all
93
+
94
+ def identify_padding_tokens(
95
+ self,
96
+ embeddings: torch.Tensor,
97
+ attention_mask: Optional[torch.Tensor] = None,
98
+ pad_token_id: Optional[int] = None,
99
+ input_ids: Optional[torch.Tensor] = None
100
+ ) -> torch.Tensor:
101
+ """
102
+ Identify padding token positions in the sequence.
103
+
104
+ Args:
105
+ embeddings: Token embeddings [batch, seq_len, hidden_size]
106
+ attention_mask: Attention mask where 0 indicates padding
107
+ pad_token_id: ID of the padding token
108
+ input_ids: Input token IDs
109
+
110
+ Returns:
111
+ Boolean mask indicating padding positions [batch, seq_len]
112
+ """
113
+ batch_size, seq_len = embeddings.shape[:2]
114
+
115
+ if attention_mask is not None:
116
+ # Attention mask: 1 for real tokens, 0 for padding
117
+ return ~attention_mask.bool()
118
+ elif pad_token_id is not None and input_ids is not None:
119
+ return input_ids == pad_token_id
120
+ else:
121
+ # Fallback: assume last 25% of sequence are padding tokens
122
+ # This is a heuristic based on common practice
123
+ padding_start = int(seq_len * 0.75)
124
+ mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=embeddings.device)
125
+ mask[:, padding_start:] = True
126
+ return mask
127
+
128
+ def forward(
129
+ self,
130
+ embeddings: torch.Tensor,
131
+ attention_mask: Optional[torch.Tensor] = None,
132
+ pad_token_id: Optional[int] = None,
133
+ input_ids: Optional[torch.Tensor] = None,
134
+ ) -> torch.Tensor:
135
+ """
136
+ Apply padding token orthogonalization.
137
+
138
+ Args:
139
+ embeddings: Token embeddings [batch, seq_len, hidden_size]
140
+ attention_mask: Attention mask where 1 indicates real tokens
141
+ pad_token_id: ID of the padding token
142
+ input_ids: Input token IDs
143
+
144
+ Returns:
145
+ Enhanced embeddings with orthogonalized padding tokens
146
+ """
147
+ if not self.enabled:
148
+ return embeddings
149
+
150
+ # Store original norms if we need to preserve them
151
+ if self.preserve_norm:
152
+ original_norms = torch.norm(embeddings, dim=-1, keepdim=True)
153
+
154
+ if self.orthogonalize_all:
155
+ # Orthogonalize all tokens in the sequence
156
+ enhanced_embeddings = orthogonalize_rows(embeddings)
157
+ else:
158
+ # Only orthogonalize padding tokens
159
+ padding_mask = self.identify_padding_tokens(
160
+ embeddings, attention_mask, pad_token_id, input_ids
161
+ )
162
+
163
+ enhanced_embeddings = embeddings.clone()
164
+
165
+ # Process each sample in the batch
166
+ for batch_idx in range(embeddings.shape[0]):
167
+ padding_indices = torch.where(padding_mask[batch_idx])[0]
168
+
169
+ if len(padding_indices) > 1: # Need at least 2 tokens to orthogonalize
170
+ # Extract padding token embeddings
171
+ padding_embeddings = embeddings[batch_idx, padding_indices]
172
+
173
+ # Apply orthogonalization
174
+ orthogonalized = orthogonalize_rows(padding_embeddings)
175
+
176
+ # Put back orthogonalized embeddings
177
+ enhanced_embeddings[batch_idx, padding_indices] = orthogonalized
178
+
179
+ # Restore original norms if requested
180
+ if self.preserve_norm:
181
+ current_norms = torch.norm(enhanced_embeddings, dim=-1, keepdim=True)
182
+ enhanced_embeddings = enhanced_embeddings * (original_norms / (current_norms + 1e-8))
183
+
184
+ return enhanced_embeddings
185
+
186
+
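+ # A minimal usage sketch for the module above (batch size, sequence length, hidden size, and
+ # mask layout are illustrative assumptions, not values from the poster):
+ #
+ #     orthogonalizer = PaddingTokenOrthogonalizer(enabled=True, preserve_norm=True)
+ #     embeds = torch.randn(2, 16, 64)               # [batch, seq_len, hidden]
+ #     mask = torch.ones(2, 16, dtype=torch.long)
+ #     mask[:, 10:] = 0                              # last 6 positions are padding
+ #     enhanced = orthogonalizer(embeds, attention_mask=mask)
+ #     assert enhanced.shape == embeds.shape
+ #     assert torch.allclose(enhanced[:, :10], embeds[:, :10])   # real tokens essentially unchanged
+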
187
+ def apply_padding_token_orthogonalization(
188
+ prompt_embeds: torch.Tensor,
189
+ text_attention_mask: Optional[torch.Tensor] = None,
190
+ config: Optional[dict] = None,
191
+ ) -> torch.Tensor:
192
+ """
193
+ Convenience function to apply padding token orthogonalization to prompt embeddings.
194
+
195
+ Args:
196
+ prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size]
197
+ text_attention_mask: Attention mask for text tokens
198
+ config: Configuration dictionary with orthogonalization settings
199
+
200
+ Returns:
201
+ Enhanced prompt embeddings
202
+ """
203
+ if config is None:
204
+ config = {}
205
+
206
+ orthogonalizer = PaddingTokenOrthogonalizer(
207
+ enabled=config.get('padding_orthogonalization_enabled', True),
208
+ preserve_norm=config.get('preserve_norm', True),
209
+ orthogonalize_all=config.get('orthogonalize_all_tokens', False),
210
+ )
211
+
212
+ return orthogonalizer(
213
+ embeddings=prompt_embeds,
214
+ attention_mask=text_attention_mask,
215
+ )
216
+
217
+
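+ # A minimal sketch of the convenience wrapper driven by a config dict (the 512/4096 shape
+ # mirrors T5 prompt embeddings in Flux, but is only an assumed example here):
+ #
+ #     config = {
+ #         "padding_orthogonalization_enabled": True,
+ #         "preserve_norm": True,
+ #         "orthogonalize_all_tokens": False,
+ #     }
+ #     prompt_embeds = torch.randn(1, 512, 4096)     # [batch, max_seq_len, hidden]
+ #     enhanced = apply_padding_token_orthogonalization(prompt_embeds, config=config)
+ #     # With no attention mask or input_ids, the last 25% of positions are treated as padding.
+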
218
+ # Gram-Schmidt orthogonalization alternative implementation
219
+ def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
220
+ """
221
+ Alternative implementation using explicit Gram-Schmidt process.
222
+ This provides more control but is generally slower than QR decomposition.
223
+
224
+ Args:
225
+ vectors: Input vectors to orthogonalize [n_vectors, dim]
226
+ eps: Small epsilon for numerical stability
227
+
228
+ Returns:
229
+ Orthogonalized vectors
230
+ """
231
+ n_vectors = vectors.shape[0]
232
+ orthogonal_vectors = torch.zeros_like(vectors)
233
+
234
+ for i in range(n_vectors):
235
+ vector = vectors[i].clone()
236
+
237
+ # Subtract projections onto previous orthogonal vectors
238
+ for j in range(i):
239
+ projection = torch.dot(vector, orthogonal_vectors[j]) / (
240
+ torch.dot(orthogonal_vectors[j], orthogonal_vectors[j]) + eps
241
+ )
242
+ vector = vector - projection * orthogonal_vectors[j]
243
+
244
+ # Normalize
245
+ norm = torch.norm(vector)
246
+ if norm > eps:
247
+ orthogonal_vectors[i] = vector / norm
248
+ else:
249
+ # Handle zero vector case
250
+ orthogonal_vectors[i] = vector
251
+
252
+ return orthogonal_vectors
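+
+
+ # A small sanity check (assumed example sizes) showing that the explicit Gram-Schmidt variant,
+ # like the QR-based path, yields near-orthonormal rows:
+ #
+ #     vectors = torch.randn(6, 32)
+ #     ortho = gram_schmidt_orthogonalization(vectors)
+ #     assert torch.allclose(ortho @ ortho.T, torch.eye(6), atol=1e-4)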
flux/pipeline_tools.py ADDED
@@ -0,0 +1,80 @@
1
+ from diffusers.pipelines import FluxPipeline
2
+ from diffusers.utils import logging
3
+ from diffusers.pipelines.flux.pipeline_flux import logger
4
+ from torch import Tensor
5
+ from typing import Optional, Dict, Any
6
+ from .padding_orthogonalization import apply_padding_token_orthogonalization
7
+
8
+
9
+ def encode_images(pipeline: FluxPipeline, images: Tensor):
10
+ images = pipeline.image_processor.preprocess(images)
11
+ images = images.to(pipeline.device).to(pipeline.dtype)
12
+ images = pipeline.vae.encode(images).latent_dist.sample()
13
+ images = (
14
+ images - pipeline.vae.config.shift_factor
15
+ ) * pipeline.vae.config.scaling_factor
16
+ images_tokens = pipeline._pack_latents(images, *images.shape)
17
+ images_ids = pipeline._prepare_latent_image_ids(
18
+ images.shape[0],
19
+ images.shape[2],
20
+ images.shape[3],
21
+ pipeline.device,
22
+ pipeline.dtype,
23
+ )
24
+ if images_tokens.shape[1] != images_ids.shape[0]:
25
+ images_ids = pipeline._prepare_latent_image_ids(
26
+ images.shape[0],
27
+ images.shape[2] // 2,
28
+ images.shape[3] // 2,
29
+ pipeline.device,
30
+ pipeline.dtype,
31
+ )
32
+ return images_tokens, images_ids
33
+
34
+
35
+ def prepare_text_input(
36
+ pipeline: FluxPipeline,
37
+ prompts,
38
+ max_sequence_length=512,
39
+ model_config: Optional[Dict[str, Any]] = None
40
+ ):
41
+ """
42
+ Prepare text input with optional padding token orthogonalization.
43
+
44
+ Args:
45
+ pipeline: FluxPipeline instance
46
+ prompts: Text prompts to encode
47
+ max_sequence_length: Maximum sequence length
48
+ model_config: Optional configuration for orthogonalization
49
+
50
+ Returns:
51
+ Tuple of (prompt_embeds, pooled_prompt_embeds, text_ids)
52
+ """
53
+ # Turn off warnings (CLIP overflow)
54
+ logger.setLevel(logging.ERROR)
55
+ (
56
+ prompt_embeds,
57
+ pooled_prompt_embeds,
58
+ text_ids,
59
+ ) = pipeline.encode_prompt(
60
+ prompt=prompts,
61
+ prompt_2=None,
62
+ prompt_embeds=None,
63
+ pooled_prompt_embeds=None,
64
+ device=pipeline.device,
65
+ num_images_per_prompt=1,
66
+ max_sequence_length=max_sequence_length,
67
+ lora_scale=None,
68
+ )
69
+
70
+ # Apply padding token orthogonalization if configured
71
+ if model_config and model_config.get('padding_orthogonalization_enabled', False):
72
+ prompt_embeds = apply_padding_token_orthogonalization(
73
+ prompt_embeds=prompt_embeds,
74
+ text_attention_mask=None,
75
+ config=model_config,
76
+ )
77
+
78
+ # Turn on warnings
79
+ logger.setLevel(logging.WARNING)
80
+ return prompt_embeds, pooled_prompt_embeds, text_ids
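+
+
+ # A hedged usage sketch: encode a prompt with the orthogonalization hook enabled. It assumes a
+ # FluxPipeline has already been loaded (the model id is only an example and may require access):
+ #
+ #     import torch
+ #     from diffusers import FluxPipeline
+ #     from flux.pipeline_tools import prepare_text_input
+ #
+ #     pipe = FluxPipeline.from_pretrained(
+ #         "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ #     ).to("cuda")
+ #     cfg = {"padding_orthogonalization_enabled": True, "preserve_norm": True}
+ #     prompt_embeds, pooled_embeds, text_ids = prepare_text_input(
+ #         pipe, ["a studio product photo of a red sneaker"], model_config=cfg
+ #     )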
flux/transformer.py ADDED
@@ -0,0 +1,252 @@
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from typing import List, Union, Optional, Dict, Any, Callable
4
+ from .block import block_forward, single_block_forward
5
+ from .lora_controller import enable_lora
6
+ from accelerate.utils import is_torch_version
7
+ from diffusers.models.transformers.transformer_flux import (
8
+ FluxTransformer2DModel,
9
+ Transformer2DModelOutput,
10
+ USE_PEFT_BACKEND,
11
+ scale_lora_layers,
12
+ unscale_lora_layers,
13
+ logger,
14
+ )
15
+ import numpy as np
16
+
17
+
18
+ def prepare_params(
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor = None,
21
+ pooled_projections: torch.Tensor = None,
22
+ timestep: torch.LongTensor = None,
23
+ img_ids: torch.Tensor = None,
24
+ txt_ids: torch.Tensor = None,
25
+ guidance: torch.Tensor = None,
26
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27
+ controlnet_block_samples=None,
28
+ controlnet_single_block_samples=None,
29
+ return_dict: bool = True,
30
+ **kwargs: dict,
31
+ ):
32
+ return (
33
+ hidden_states,
34
+ encoder_hidden_states,
35
+ pooled_projections,
36
+ timestep,
37
+ img_ids,
38
+ txt_ids,
39
+ guidance,
40
+ joint_attention_kwargs,
41
+ controlnet_block_samples,
42
+ controlnet_single_block_samples,
43
+ return_dict,
44
+ )
45
+
46
+
47
+ def tranformer_forward(
48
+ transformer: FluxTransformer2DModel,
49
+ condition_latents: torch.Tensor,
50
+ condition_ids: torch.Tensor,
51
+ condition_type_ids: torch.Tensor,
52
+ model_config: Optional[Dict[str, Any]] = {},
53
+ c_t=0,
54
+ **params: dict,
55
+ ):
56
+ self = transformer
57
+ use_condition = condition_latents is not None
58
+
59
+ (
60
+ hidden_states,
61
+ encoder_hidden_states,
62
+ pooled_projections,
63
+ timestep,
64
+ img_ids,
65
+ txt_ids,
66
+ guidance,
67
+ joint_attention_kwargs,
68
+ controlnet_block_samples,
69
+ controlnet_single_block_samples,
70
+ return_dict,
71
+ ) = prepare_params(**params)
72
+
73
+ if joint_attention_kwargs is not None:
74
+ joint_attention_kwargs = joint_attention_kwargs.copy()
75
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76
+ else:
77
+ lora_scale = 1.0
78
+
79
+ if USE_PEFT_BACKEND:
80
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
81
+ scale_lora_layers(self, lora_scale)
82
+ else:
83
+ if (
84
+ joint_attention_kwargs is not None
85
+ and joint_attention_kwargs.get("scale", None) is not None
86
+ ):
87
+ logger.warning(
88
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89
+ )
90
+
91
+ with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92
+ hidden_states = self.x_embedder(hidden_states)
93
+ condition_latents = self.x_embedder(condition_latents) if use_condition else None
94
+
95
+ timestep = timestep.to(hidden_states.dtype) * 1000
96
+
97
+ if guidance is not None:
98
+ guidance = guidance.to(hidden_states.dtype) * 1000
99
+ else:
100
+ guidance = None
101
+
102
+ temb = (
103
+ self.time_text_embed(timestep, pooled_projections)
104
+ if guidance is None
105
+ else self.time_text_embed(timestep, guidance, pooled_projections)
106
+ )
107
+
108
+ cond_temb = (
109
+ self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
110
+ if guidance is None
111
+ else self.time_text_embed(
112
+ torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
113
+ )
114
+ )
115
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
116
+
117
+ if txt_ids.ndim == 3:
118
+ logger.warning(
119
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
120
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
121
+ )
122
+ txt_ids = txt_ids[0]
123
+ if img_ids.ndim == 3:
124
+ logger.warning(
125
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
126
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
127
+ )
128
+ img_ids = img_ids[0]
129
+
130
+ ids = torch.cat((txt_ids, img_ids), dim=0)
131
+ image_rotary_emb = self.pos_embed(ids)
132
+ if use_condition:
133
+ # condition_ids[:, :1] = condition_type_ids
134
+ cond_rotary_emb = self.pos_embed(condition_ids)
135
+
136
+ # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
137
+
138
+ for index_block, block in enumerate(self.transformer_blocks):
139
+ if self.training and self.gradient_checkpointing:
140
+ ckpt_kwargs: Dict[str, Any] = (
141
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
142
+ )
143
+ encoder_hidden_states, hidden_states, condition_latents = (
144
+ torch.utils.checkpoint.checkpoint(
145
+ block_forward,
146
+ self=block,
147
+ model_config=model_config,
148
+ hidden_states=hidden_states,
149
+ encoder_hidden_states=encoder_hidden_states,
150
+ condition_latents=condition_latents if use_condition else None,
151
+ temb=temb,
152
+ cond_temb=cond_temb if use_condition else None,
153
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
154
+ image_rotary_emb=image_rotary_emb,
155
+ **ckpt_kwargs,
156
+ )
157
+ )
158
+
159
+ else:
160
+ encoder_hidden_states, hidden_states, condition_latents = block_forward(
161
+ block,
162
+ model_config=model_config,
163
+ hidden_states=hidden_states,
164
+ encoder_hidden_states=encoder_hidden_states,
165
+ condition_latents=condition_latents if use_condition else None,
166
+ temb=temb,
167
+ cond_temb=cond_temb if use_condition else None,
168
+ cond_rotary_emb=cond_rotary_emb if use_condition else None,
169
+ image_rotary_emb=image_rotary_emb,
170
+ )
171
+
172
+ # controlnet residual
173
+ if controlnet_block_samples is not None:
174
+ interval_control = len(self.transformer_blocks) / len(
175
+ controlnet_block_samples
176
+ )
177
+ interval_control = int(np.ceil(interval_control))
178
+ hidden_states = (
179
+ hidden_states
180
+ + controlnet_block_samples[index_block // interval_control]
181
+ )
182
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
183
+
184
+ for index_block, block in enumerate(self.single_transformer_blocks):
185
+ if self.training and self.gradient_checkpointing:
186
+ ckpt_kwargs: Dict[str, Any] = (
187
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188
+ )
189
+ result = torch.utils.checkpoint.checkpoint(
190
+ single_block_forward,
191
+ self=block,
192
+ model_config=model_config,
193
+ hidden_states=hidden_states,
194
+ temb=temb,
195
+ image_rotary_emb=image_rotary_emb,
196
+ **(
197
+ {
198
+ "condition_latents": condition_latents,
199
+ "cond_temb": cond_temb,
200
+ "cond_rotary_emb": cond_rotary_emb,
201
+ }
202
+ if use_condition
203
+ else {}
204
+ ),
205
+ **ckpt_kwargs,
206
+ )
207
+
208
+ else:
209
+ result = single_block_forward(
210
+ block,
211
+ model_config=model_config,
212
+ hidden_states=hidden_states,
213
+ temb=temb,
214
+ image_rotary_emb=image_rotary_emb,
215
+ **(
216
+ {
217
+ "condition_latents": condition_latents,
218
+ "cond_temb": cond_temb,
219
+ "cond_rotary_emb": cond_rotary_emb,
220
+ }
221
+ if use_condition
222
+ else {}
223
+ ),
224
+ )
225
+ if use_condition:
226
+ hidden_states, condition_latents = result
227
+ else:
228
+ hidden_states = result
229
+
230
+ # controlnet residual
231
+ if controlnet_single_block_samples is not None:
232
+ interval_control = len(self.single_transformer_blocks) / len(
233
+ controlnet_single_block_samples
234
+ )
235
+ interval_control = int(np.ceil(interval_control))
236
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
237
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
238
+ + controlnet_single_block_samples[index_block // interval_control]
239
+ )
240
+
241
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
242
+
243
+ hidden_states = self.norm_out(hidden_states, temb)
244
+ output = self.proj_out(hidden_states)
245
+
246
+ if USE_PEFT_BACKEND:
247
+ # remove `lora_scale` from each PEFT layer
248
+ unscale_lora_layers(self, lora_scale)
249
+
250
+ if not return_dict:
251
+ return (output,)
252
+ return Transformer2DModelOutput(sample=output)
pyproject.toml ADDED
@@ -0,0 +1,66 @@
1
+ [project]
2
+ name = "ads_gen"
3
+ version = "0.1.0"
4
+ description = "ZenCtrl AP"
5
+ authors = [{ name = "Dummy User", email = "dummy@gmail.com" }]
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "diffusers==0.35.0",
9
+ "gradio>=5.29.0",
10
+ "jupyter>=1.1.1",
11
+ "matplotlib>=3.10.3",
12
+ "opencv-python>=4.11.0.86",
13
+ "peft>=0.17.0",
14
+ "protobuf>=4.21.5",
15
+ "sentencepiece>=0.2.0",
16
+ "torchao>=0.10.0",
17
+ "torchvision>=0.22.0",
18
+ "transformers>=4.55.0",
19
+ "datasets>=2.13.0,<3",
20
+ "gcsfs>=2023.1.0,<2024",
21
+ "pillow>=9.5.0,<10",
22
+ "setuptools>=68.0.0,<69",
23
+ "tensorboard>=2.13.0,<3",
24
+ "omegaconf>=2.3.0,<3",
25
+ "einops>=0.6.1,<0.7",
26
+ "scipy>1.10.1",
27
+ "seaborn>=0.12.2,<0.13",
28
+ "tensorflow>=2.12.0,<3",
29
+ "tensorflow-datasets>=4.9.2,<5",
30
+ "hydra-core>=1.3.2,<2",
31
+ "torch-tb-profiler>=0.4.1,<0.5",
32
+ "faiss-cpu>=1.7.4,<2",
33
+ "triton==3.3.0",
34
+ "bitsandbytes==0.45.2",
35
+ "prdc>=0.2,<0.3",
36
+ "pytorch-fid>=0.3.0,<0.4",
37
+ "python-json-logger>=2.0.7,<3",
38
+ "multiprocess>=0.70.12",
39
+ "pyyaml>=6.0.1,<7",
40
+ "timm>=0.9.5,<0.10",
41
+ "rich>=13.5.2,<14",
42
+ "gdown>=4.7.1,<5",
43
+ "dreamsim>=0.1.3",
44
+ "scikit-image>=0.24.0",
45
+ "nvitop>=1.5.0",
46
+ "segment-anything==1.0",
47
+ ]
48
+
49
+ [tool.hatch.build.targets.wheel]
50
+ packages = ["app", "ralf"]
51
+
52
+ [tool.hatch.build.targets.sdist]
53
+ include = ["app", "ralf"]
54
+
55
+ [build-system]
56
+ requires = ["hatchling"]
57
+ build-backend = "hatchling.build"
58
+
59
+ [tool.uv]
60
+ [[tool.uv.index]]
61
+ name = "pytorch-cu124"
62
+ url = "https://download.pytorch.org/whl/cu124"
63
+ explicit = true
64
+
65
+ [tool.uv.sources]
66
+ segment-anything = { git = "https://github.com/facebookresearch/segment-anything.git" }
requirements.txt ADDED
@@ -0,0 +1,938 @@
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml
3
+ absl-py==2.3.1
4
+ # via
5
+ # array-record
6
+ # dm-tree
7
+ # etils
8
+ # keras
9
+ # tensorboard
10
+ # tensorflow
11
+ # tensorflow-datasets
12
+ # tensorflow-metadata
13
+ accelerate==1.10.1
14
+ # via peft
15
+ aiofiles==24.1.0
16
+ # via gradio
17
+ aiohappyeyeballs==2.6.1
18
+ # via aiohttp
19
+ aiohttp==3.13.0
20
+ # via
21
+ # datasets
22
+ # fsspec
23
+ # gcsfs
24
+ aiosignal==1.4.0
25
+ # via aiohttp
26
+ annotated-types==0.7.0
27
+ # via pydantic
28
+ antlr4-python3-runtime==4.9.3
29
+ # via
30
+ # hydra-core
31
+ # omegaconf
32
+ anyio==4.11.0
33
+ # via
34
+ # gradio
35
+ # httpx
36
+ # jupyter-server
37
+ # starlette
38
+ argon2-cffi==25.1.0
39
+ # via jupyter-server
40
+ argon2-cffi-bindings==25.1.0
41
+ # via argon2-cffi
42
+ array-record==0.8.1
43
+ # via tensorflow-datasets
44
+ arrow==1.3.0
45
+ # via isoduration
46
+ asttokens==3.0.0
47
+ # via stack-data
48
+ astunparse==1.6.3
49
+ # via tensorflow
50
+ async-lru==2.0.5
51
+ # via jupyterlab
52
+ attrs==25.4.0
53
+ # via
54
+ # aiohttp
55
+ # dm-tree
56
+ # jsonschema
57
+ # referencing
58
+ babel==2.17.0
59
+ # via jupyterlab-server
60
+ beautifulsoup4==4.14.2
61
+ # via
62
+ # gdown
63
+ # nbconvert
64
+ bitsandbytes==0.45.2
65
+ # via ads-gen (pyproject.toml)
66
+ bleach==6.2.0
67
+ # via nbconvert
68
+ brotli==1.1.0
69
+ # via gradio
70
+ cachetools==6.2.0
71
+ # via google-auth
72
+ certifi==2025.10.5
73
+ # via
74
+ # httpcore
75
+ # httpx
76
+ # requests
77
+ cffi==2.0.0
78
+ # via argon2-cffi-bindings
79
+ charset-normalizer==3.4.3
80
+ # via requests
81
+ click==8.3.0
82
+ # via
83
+ # typer
84
+ # uvicorn
85
+ comm==0.2.3
86
+ # via
87
+ # ipykernel
88
+ # ipywidgets
89
+ # via matplotlib
90
+ cycler==0.12.1
91
+ # via matplotlib
92
+ datasets==2.21.0
93
+ # via ads-gen (pyproject.toml)
94
+ debugpy==1.8.17
95
+ # via ipykernel
96
+ decorator==5.2.1
97
+ # via
98
+ # gcsfs
99
+ # ipython
100
+ defusedxml==0.7.1
101
+ # via nbconvert
102
+ diffusers==0.35.0
103
+ # via ads-gen (pyproject.toml)
104
+ dill==0.3.8
105
+ # via
106
+ # datasets
107
+ # multiprocess
108
+ dm-tree==0.1.9
109
+ # via tensorflow-datasets
110
+ docstring-parser==0.17.0
111
+ # via simple-parsing
112
+ dreamsim==0.2.1
113
+ # via ads-gen (pyproject.toml)
114
+ einops==0.6.1
115
+ # via
116
+ # ads-gen (pyproject.toml)
117
+ # etils
118
+ etils==1.13.0
119
+ # via
120
+ # array-record
121
+ # tensorflow-datasets
122
+ executing==2.2.1
123
+ # via stack-data
124
+ faiss-cpu==1.12.0
125
+ # via ads-gen (pyproject.toml)
126
+ fastapi==0.118.2
127
+ # via gradio
128
+ fastjsonschema==2.21.2
129
+ # via nbformat
130
+ ffmpy==0.6.2
131
+ # via gradio
132
+ filelock==3.20.0
133
+ # via
134
+ # datasets
135
+ # diffusers
136
+ # gdown
137
+ # huggingface-hub
138
+ # torch
139
+ # transformers
140
+ flatbuffers==25.9.23
141
+ # via tensorflow
142
+ fonttools==4.60.1
143
+ # via matplotlib
144
+ fqdn==1.5.1
145
+ # via jsonschema
146
+ frozenlist==1.8.0
147
+ # via
148
+ # aiohttp
149
+ # aiosignal
150
+ fsspec==2023.12.2
151
+ # via
152
+ # datasets
153
+ # etils
154
+ # gcsfs
155
+ # gradio-client
156
+ # huggingface-hub
157
+ # torch
158
+ ftfy==6.3.1
159
+ # via open-clip-torch
160
+ gast==0.6.0
161
+ # via tensorflow
162
+ gcsfs==2023.12.2.post1
163
+ # via ads-gen (pyproject.toml)
164
+ gdown==4.7.3
165
+ # via ads-gen (pyproject.toml)
166
+ google-api-core
167
+ # via
168
+ # google-cloud-core
169
+ # google-cloud-storage
170
+ google-auth==2.41.1
171
+ # via
172
+ # gcsfs
173
+ # google-api-core
174
+ # google-auth-oauthlib
175
+ # google-cloud-core
176
+ # google-cloud-storage
177
+ google-auth-oauthlib==1.2.2
178
+ # via gcsfs
179
+ google-cloud-core==2.4.3
180
+ # via google-cloud-storage
181
+ google-cloud-storage==3.4.1
182
+ # via gcsfs
183
+ google-crc32c==1.7.1
184
+ # via
185
+ # google-cloud-storage
186
+ # google-resumable-media
187
+ google-pasta==0.2.0
188
+ # via tensorflow
189
+ google-resumable-media==2.7.2
190
+ # via google-cloud-storage
191
+ googleapis-common-protos
192
+ # via
193
+ # google-api-core
194
+ # tensorflow-metadata
195
+ gradio==5.49.1
196
+ # via ads-gen (pyproject.toml)
197
+ gradio-client==1.13.3
198
+ # via gradio
199
+ groovy==0.1.2
200
+ # via gradio
201
+ grpcio==1.75.1
202
+ # via
203
+ # tensorboard
204
+ # tensorflow
205
+ h11==0.16.0
206
+ # via
207
+ # httpcore
208
+ # uvicorn
209
+ h5py==3.14.0
210
+ # via
211
+ # keras
212
+ # tensorflow
213
+ hf-xet==1.1.10
214
+ # via huggingface-hub
215
+ httpcore==1.0.9
216
+ # via httpx
217
+ httpx==0.28.1
218
+ # via
219
+ # gradio
220
+ # gradio-client
221
+ # jupyterlab
222
+ # safehttpx
223
+ huggingface-hub==0.35.3
224
+ # via
225
+ # accelerate
226
+ # datasets
227
+ # diffusers
228
+ # gradio
229
+ # gradio-client
230
+ # open-clip-torch
231
+ # peft
232
+ # timm
233
+ # tokenizers
234
+ # transformers
235
+ hydra-core==1.3.2
236
+ # via ads-gen (pyproject.toml)
237
+ idna==3.10
238
+ # via
239
+ # anyio
240
+ # httpx
241
+ # jsonschema
242
+ # requests
243
+ # yarl
244
+ imageio==2.37.0
245
+ # via scikit-image
246
+ immutabledict==4.2.1
247
+ # via tensorflow-datasets
248
+ importlib-metadata==8.7.0
249
+ # via diffusers
250
+ importlib-resources==6.5.2
251
+ # via etils
252
+ ipykernel==6.30.1
253
+ # via
254
+ # jupyter
255
+ # jupyter-console
256
+ # jupyterlab
257
+ ipython
258
+ # via
259
+ # ipykernel
260
+ # ipywidgets
261
+ # jupyter-console
262
+ ipython-pygments-lexers==1.1.1
263
+ # via ipython
264
+ ipywidgets==8.1.7
265
+ # via jupyter
266
+ isoduration==20.11.0
267
+ # via jsonschema
268
+ jedi==0.19.2
269
+ # via ipython
270
+ jinja2==3.1.6
271
+ # via
272
+ # gradio
273
+ # jupyter-server
274
+ # jupyterlab
275
+ # jupyterlab-server
276
+ # nbconvert
277
+ # torch
278
+ joblib==1.5.2
279
+ # via
280
+ # prdc
281
+ # scikit-learn
282
+ json5==0.12.1
283
+ # via jupyterlab-server
284
+ jsonpointer==3.0.0
285
+ # via jsonschema
286
+ jsonschema==4.25.1
287
+ # via
288
+ # jupyter-events
289
+ # jupyterlab-server
290
+ # nbformat
291
+ jsonschema-specifications==2025.9.1
292
+ # via jsonschema
293
+ jupyter==1.1.1
294
+ # via ads-gen (pyproject.toml)
295
+ jupyter-client==8.6.3
296
+ # via
297
+ # ipykernel
298
+ # jupyter-console
299
+ # jupyter-server
300
+ # nbclient
301
+ jupyter-console==6.6.3
302
+ # via jupyter
303
+ jupyter-core==5.8.1
304
+ # via
305
+ # ipykernel
306
+ # jupyter-client
307
+ # jupyter-console
308
+ # jupyter-server
309
+ # jupyterlab
310
+ # nbclient
311
+ # nbconvert
312
+ # nbformat
313
+ jupyter-events==0.12.0
314
+ # via jupyter-server
315
+ jupyter-lsp==2.3.0
316
+ # via jupyterlab
317
+ jupyter-server==2.17.0
318
+ # via
319
+ # jupyter-lsp
320
+ # jupyterlab
321
+ # jupyterlab-server
322
+ # notebook
323
+ # notebook-shim
324
+ jupyter-server-terminals==0.5.3
325
+ # via jupyter-server
326
+ jupyterlab==4.4.9
327
+ # via
328
+ # jupyter
329
+ # notebook
330
+ jupyterlab-pygments==0.3.0
331
+ # via nbconvert
332
+ jupyterlab-server==2.27.3
333
+ # via
334
+ # jupyterlab
335
+ # notebook
336
+ jupyterlab-widgets==3.0.15
337
+ # via ipywidgets
338
+ keras==3.11.3
339
+ # via tensorflow
340
+ kiwisolver==1.4.9
341
+ # via matplotlib
342
+ lark==1.3.0
343
+ # via rfc3987-syntax
344
+ lazy-loader==0.4
345
+ # via scikit-image
346
+ libclang==18.1.1
347
+ # via tensorflow
348
+ markdown==3.9
349
+ # via tensorboard
350
+ markdown-it-py==4.0.0
351
+ # via rich
352
+ markupsafe==3.0.3
353
+ # via
354
+ # gradio
355
+ # jinja2
356
+ # nbconvert
357
+ # werkzeug
358
+ matplotlib==3.10.7
359
+ # via
360
+ # ads-gen (pyproject.toml)
361
+ # seaborn
362
+ matplotlib-inline==0.1.7
363
+ # via
364
+ # ipykernel
365
+ # ipython
366
+ mdurl==0.1.2
367
+ # via markdown-it-py
368
+ mistune==3.1.4
369
+ # via nbconvert
370
+ ml-dtypes==0.5.3
371
+ # via
372
+ # keras
373
+ # tensorflow
374
+ mpmath==1.3.0
375
+ # via sympy
376
+ multidict==6.7.0
377
+ # via
378
+ # aiohttp
379
+ # yarl
380
+ multiprocess==0.70.16
381
+ # via
382
+ # ads-gen (pyproject.toml)
383
+ # datasets
384
+ namex==0.1.0
385
+ # via keras
386
+ nbclient==0.10.2
387
+ # via nbconvert
388
+ nbconvert==7.16.6
389
+ # via
390
+ # jupyter
391
+ # jupyter-server
392
+ nbformat==5.10.4
393
+ # via
394
+ # jupyter-server
395
+ # nbclient
396
+ # nbconvert
397
+ nest-asyncio==1.6.0
398
+ # via ipykernel
399
+ networkx
400
+ # via
401
+ # scikit-image
402
+ # torch
403
+ notebook==7.4.7
404
+ # via jupyter
405
+ notebook-shim==0.2.4
406
+ # via
407
+ # jupyterlab
408
+ # notebook
409
+ numpy==1.26.4
410
+ # via
411
+ # accelerate
412
+ # bitsandbytes
413
+ # contourpy
414
+ # datasets
415
+ # diffusers
416
+ # dm-tree
417
+ # dreamsim
418
+ # etils
419
+ # faiss-cpu
420
+ # gradio
421
+ # h5py
422
+ # imageio
423
+ # keras
424
+ # matplotlib
425
+ # ml-dtypes
426
+ # opencv-python
427
+ # pandas
428
+ # peft
429
+ # prdc
430
+ # pytorch-fid
431
+ # scikit-image
432
+ # scikit-learn
433
+ # scipy
434
+ # seaborn
435
+ # tensorboard
436
+ # tensorflow
437
+ # tensorflow-datasets
438
+ # tifffile
439
+ # torchvision
440
+ # transformers
441
+ nvidia-cublas-cu12==12.6.4.1
442
+ # via
443
+ # nvidia-cudnn-cu12
444
+ # nvidia-cusolver-cu12
445
+ # torch
446
+ nvidia-cuda-cupti-cu12==12.6.80
447
+ # via torch
448
+ nvidia-cuda-nvrtc-cu12==12.6.77
449
+ # via torch
450
+ nvidia-cuda-runtime-cu12==12.6.77
451
+ # via torch
452
+ nvidia-cudnn-cu12==9.5.1.17
453
+ # via torch
454
+ nvidia-cufft-cu12==11.3.0.4
455
+ # via torch
456
+ nvidia-cufile-cu12==1.11.1.6
457
+ # via torch
458
+ nvidia-curand-cu12==10.3.7.77
459
+ # via torch
460
+ nvidia-cusolver-cu12==11.7.1.2
461
+ # via torch
462
+ nvidia-cusparse-cu12==12.5.4.2
463
+ # via
464
+ # nvidia-cusolver-cu12
465
+ # torch
466
+ nvidia-cusparselt-cu12==0.6.3
467
+ # via torch
468
+ nvidia-ml-py==13.580.82
469
+ # via nvitop
470
+ nvidia-nccl-cu12==2.26.2
471
+ # via torch
472
+ nvidia-nvjitlink-cu12==12.6.85
473
+ # via
474
+ # nvidia-cufft-cu12
475
+ # nvidia-cusolver-cu12
476
+ # nvidia-cusparse-cu12
477
+ # torch
478
+ nvidia-nvtx-cu12==12.6.77
479
+ # via torch
480
+ nvitop==1.5.3
481
+ # via ads-gen (pyproject.toml)
482
+ oauthlib==3.3.1
483
+ # via requests-oauthlib
484
+ omegaconf==2.3.0
485
+ # via
486
+ # ads-gen (pyproject.toml)
487
+ # hydra-core
488
+ open-clip-torch==2.32.0
489
+ # via dreamsim
490
+ opencv-python==4.11.0.86
491
+ # via ads-gen (pyproject.toml)
492
+ opt-einsum==3.4.0
493
+ # via tensorflow
494
+ optree==0.17.0
495
+ # via keras
496
+ orjson==3.11.3
497
+ # via gradio
498
+ packaging==25.0
499
+ # via
500
+ # accelerate
501
+ # datasets
502
+ # faiss-cpu
503
+ # gradio
504
+ # gradio-client
505
+ # huggingface-hub
506
+ # hydra-core
507
+ # ipykernel
508
+ # jupyter-events
509
+ # jupyter-server
510
+ # jupyterlab
511
+ # jupyterlab-server
512
+ # keras
513
+ # lazy-loader
514
+ # matplotlib
515
+ # nbconvert
516
+ # peft
517
+ # scikit-image
518
+ # tensorboard
519
+ # tensorflow
520
+ # transformers
521
+ pandas==2.3.3
522
+ # via
523
+ # datasets
524
+ # gradio
525
+ # seaborn
526
+ # torch-tb-profiler
527
+ pandocfilters==1.5.1
528
+ # via nbconvert
529
+ parso==0.8.5
530
+ # via jedi
531
+ peft==0.17.1
532
+ # via
533
+ # ads-gen (pyproject.toml)
534
+ # dreamsim
535
+ pexpect==4.9.0
536
+ # via ipython
537
+ pillow==9.5.0
538
+ # via
539
+ # ads-gen (pyproject.toml)
540
+ # diffusers
541
+ # dreamsim
542
+ # gradio
543
+ # imageio
544
+ # matplotlib
545
+ # pytorch-fid
546
+ # scikit-image
547
+ # tensorboard
548
+ # torchvision
549
+ platformdirs==4.5.0
550
+ # via jupyter-core
551
+ prdc==0.2
552
+ # via ads-gen (pyproject.toml)
553
+ prometheus-client==0.23.1
554
+ # via jupyter-server
555
+ promise==2.3
556
+ # via tensorflow-datasets
557
+ prompt-toolkit==3.0.52
558
+ # via
559
+ # ipython
560
+ # jupyter-console
561
+ propcache==0.4.1
562
+ # via
563
+ # aiohttp
564
+ # yarl
565
+ proto-plus
566
+ # via google-api-core
567
+ protobuf
568
+ # via
569
+ # ads-gen (pyproject.toml)
570
+ # google-api-core
571
+ # googleapis-common-protos
572
+ # proto-plus
573
+ # tensorboard
574
+ # tensorflow
575
+ # tensorflow-datasets
576
+ # tensorflow-metadata
577
+ psutil==7.1.0
578
+ # via
579
+ # accelerate
580
+ # ipykernel
581
+ # nvitop
582
+ # peft
583
+ # tensorflow-datasets
584
+ ptyprocess==0.7.0
585
+ # via
586
+ # pexpect
587
+ # terminado
588
+ pure-eval==0.2.3
589
+ # via stack-data
590
+ pyarrow==21.0.0
591
+ # via
592
+ # datasets
593
+ # tensorflow-datasets
594
+ pyasn1==0.6.1
595
+ # via
596
+ # pyasn1-modules
597
+ # rsa
598
+ pyasn1-modules==0.4.2
599
+ # via google-auth
600
+ pycparser==2.23
601
+ # via cffi
602
+ pydantic==2.11.10
603
+ # via
604
+ # fastapi
605
+ # gradio
606
+ pydantic-core==2.33.2
607
+ # via pydantic
608
+ pydub==0.25.1
609
+ # via gradio
610
+ pygments==2.19.2
611
+ # via
612
+ # ipython
613
+ # ipython-pygments-lexers
614
+ # jupyter-console
615
+ # nbconvert
616
+ # rich
617
+ pyparsing==3.2.5
618
+ # via matplotlib
619
+ pysocks==1.7.1
620
+ # via requests
621
+ python-dateutil==2.9.0.post0
622
+ # via
623
+ # arrow
624
+ # jupyter-client
625
+ # matplotlib
626
+ # pandas
627
+ python-json-logger==2.0.7
628
+ # via
629
+ # ads-gen (pyproject.toml)
630
+ # jupyter-events
631
+ python-multipart==0.0.20
632
+ # via gradio
633
+ pytorch-fid==0.3.0
634
+ # via ads-gen (pyproject.toml)
635
+ pytz==2025.2
636
+ # via pandas
637
+ pyyaml==6.0.3
638
+ # via
639
+ # ads-gen (pyproject.toml)
640
+ # accelerate
641
+ # datasets
642
+ # gradio
643
+ # huggingface-hub
644
+ # jupyter-events
645
+ # omegaconf
646
+ # peft
647
+ # timm
648
+ # transformers
649
+ pyzmq==27.1.0
650
+ # via
651
+ # ipykernel
652
+ # jupyter-client
653
+ # jupyter-console
654
+ # jupyter-server
655
+ referencing==0.36.2
656
+ # via
657
+ # jsonschema
658
+ # jsonschema-specifications
659
+ # jupyter-events
660
+ regex==2025.9.18
661
+ # via
662
+ # diffusers
663
+ # open-clip-torch
664
+ # transformers
665
+ requests==2.32.5
666
+ # via
667
+ # datasets
668
+ # diffusers
669
+ # fsspec
670
+ # gcsfs
671
+ # gdown
672
+ # google-api-core
673
+ # google-cloud-storage
674
+ # huggingface-hub
675
+ # jupyterlab-server
676
+ # requests-oauthlib
677
+ # tensorflow
678
+ # tensorflow-datasets
679
+ # transformers
680
+ requests-oauthlib==2.0.0
681
+ # via google-auth-oauthlib
682
+ rfc3339-validator==0.1.4
683
+ # via
684
+ # jsonschema
685
+ # jupyter-events
686
+ rfc3986-validator==0.1.1
687
+ # via
688
+ # jsonschema
689
+ # jupyter-events
690
+ rfc3987-syntax==1.1.0
691
+ # via jsonschema
692
+ rich==13.9.4
693
+ # via
694
+ # ads-gen (pyproject.toml)
695
+ # keras
696
+ # typer
697
+ rpds-py==0.27.1
698
+ # via
699
+ # jsonschema
700
+ # referencing
701
+ rsa==4.9.1
702
+ # via google-auth
703
+ ruff==0.14.0
704
+ # via gradio
705
+ safehttpx==0.1.6
706
+ # via gradio
707
+ safetensors==0.6.2
708
+ # via
709
+ # accelerate
710
+ # diffusers
711
+ # open-clip-torch
712
+ # peft
713
+ # timm
714
+ # transformers
715
+ scikit-image==0.24.0
716
+ # via ads-gen (pyproject.toml)
717
+ scikit-learn==1.7.2
718
+ # via prdc
719
+ scipy
720
+ # via
721
+ # ads-gen (pyproject.toml)
722
+ # dreamsim
723
+ # prdc
724
+ # pytorch-fid
725
+ # scikit-image
726
+ # scikit-learn
727
+ seaborn==0.12.2
728
+ # via ads-gen (pyproject.toml)
729
+ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf
730
+ # via ads-gen (pyproject.toml)
731
+ semantic-version==2.10.0
732
+ # via gradio
733
+ send2trash==1.8.3
734
+ # via jupyter-server
735
+ sentencepiece==0.2.1
736
+ # via ads-gen (pyproject.toml)
737
+ setuptools==68.2.2
738
+ # via
739
+ # ads-gen (pyproject.toml)
740
+ # jupyterlab
741
+ # tensorboard
742
+ # tensorflow
743
+ # torch
744
+ # triton
745
+ shellingham==1.5.4
746
+ # via typer
747
+ simple-parsing==0.1.7
748
+ # via tensorflow-datasets
749
+ six==1.17.0
750
+ # via
751
+ # astunparse
752
+ # gdown
753
+ # google-pasta
754
+ # promise
755
+ # python-dateutil
756
+ # rfc3339-validator
757
+ # tensorflow
758
+ sniffio==1.3.1
759
+ # via anyio
760
+ soupsieve==2.8
761
+ # via beautifulsoup4
762
+ stack-data==0.6.3
763
+ # via ipython
764
+ starlette==0.48.0
765
+ # via
766
+ # fastapi
767
+ # gradio
768
+ sympy==1.14.0
769
+ # via torch
770
+ tensorboard
771
+ # via
772
+ # ads-gen (pyproject.toml)
773
+ # tensorflow
774
+ # torch-tb-profiler
775
+ tensorboard-data-server==0.7.2
776
+ # via tensorboard
777
+ tensorflow
778
+ # via ads-gen (pyproject.toml)
779
+ tensorflow-datasets
780
+ # via ads-gen (pyproject.toml)
781
+ tensorflow-metadata
782
+ # via tensorflow-datasets
783
+ termcolor==3.1.0
784
+ # via
785
+ # tensorflow
786
+ # tensorflow-datasets
787
+ terminado==0.18.1
788
+ # via
789
+ # jupyter-server
790
+ # jupyter-server-terminals
791
+ threadpoolctl==3.6.0
792
+ # via scikit-learn
793
+ tifffile
794
+ # via scikit-image
795
+ timm==0.9.16
796
+ # via
797
+ # ads-gen (pyproject.toml)
798
+ # dreamsim
799
+ # open-clip-torch
800
+ tinycss2==1.4.0
801
+ # via bleach
802
+ tokenizers==0.22.1
803
+ # via transformers
804
+ toml==0.10.2
805
+ # via tensorflow-datasets
806
+ tomlkit==0.13.3
807
+ # via gradio
808
+ torch==2.7.0
809
+ # via
810
+ # accelerate
811
+ # bitsandbytes
812
+ # dreamsim
813
+ # open-clip-torch
814
+ # peft
815
+ # pytorch-fid
816
+ # timm
817
+ # torchvision
818
+ torch-tb-profiler==0.4.3
819
+ # via ads-gen (pyproject.toml)
820
+ torchao==0.13.0
821
+ # via ads-gen (pyproject.toml)
822
+ torchvision==0.22.0
823
+ # via
824
+ # ads-gen (pyproject.toml)
825
+ # dreamsim
826
+ # open-clip-torch
827
+ # pytorch-fid
828
+ # timm
829
+ tornado==6.5.2
830
+ # via
831
+ # ipykernel
832
+ # jupyter-client
833
+ # jupyter-server
834
+ # jupyterlab
835
+ # notebook
836
+ # terminado
837
+ tqdm==4.67.1
838
+ # via
839
+ # datasets
840
+ # etils
841
+ # gdown
842
+ # huggingface-hub
843
+ # open-clip-torch
844
+ # peft
845
+ # tensorflow-datasets
846
+ # transformers
847
+ traitlets==5.14.3
848
+ # via
849
+ # ipykernel
850
+ # ipython
851
+ # ipywidgets
852
+ # jupyter-client
853
+ # jupyter-console
854
+ # jupyter-core
855
+ # jupyter-events
856
+ # jupyter-server
857
+ # jupyterlab
858
+ # matplotlib-inline
859
+ # nbclient
860
+ # nbconvert
861
+ # nbformat
862
+ transformers==4.57.0
863
+ # via
864
+ # ads-gen (pyproject.toml)
865
+ # dreamsim
866
+ # peft
867
+ triton==3.3.0
868
+ # via
869
+ # ads-gen (pyproject.toml)
870
+ # torch
871
+ typer==0.19.2
872
+ # via gradio
873
+ types-python-dateutil==2.9.0.20251008
874
+ # via arrow
875
+ typing-extensions==4.15.0
876
+ # via
877
+ # aiosignal
878
+ # anyio
879
+ # beautifulsoup4
880
+ # etils
881
+ # fastapi
882
+ # gradio
883
+ # gradio-client
884
+ # grpcio
885
+ # huggingface-hub
886
+ # optree
887
+ # pydantic
888
+ # pydantic-core
889
+ # referencing
890
+ # simple-parsing
891
+ # starlette
892
+ # tensorflow
893
+ # torch
894
+ # typer
895
+ # typing-inspection
896
+ typing-inspection==0.4.2
897
+ # via pydantic
898
+ tzdata==2025.2
899
+ # via pandas
900
+ uri-template==1.3.0
901
+ # via jsonschema
902
+ urllib3==2.5.0
903
+ # via requests
904
+ uvicorn==0.37.0
905
+ # via gradio
906
+ wcwidth==0.2.14
907
+ # via
908
+ # ftfy
909
+ # prompt-toolkit
910
+ webcolors==24.11.1
911
+ # via jsonschema
912
+ webencodings==0.5.1
913
+ # via
914
+ # bleach
915
+ # tinycss2
916
+ websocket-client==1.9.0
917
+ # via jupyter-server
918
+ websockets==15.0.1
919
+ # via gradio-client
920
+ werkzeug==3.1.3
921
+ # via tensorboard
922
+ wheel==0.45.1
923
+ # via astunparse
924
+ widgetsnbextension==4.0.14
925
+ # via ipywidgets
926
+ wrapt==1.17.3
927
+ # via
928
+ # dm-tree
929
+ # tensorflow
930
+ # tensorflow-datasets
931
+ xxhash==3.6.0
932
+ # via datasets
933
+ yarl==1.22.0
934
+ # via aiohttp
935
+ zipp==3.23.0
936
+ # via
937
+ # etils
938
+ # importlib-metadata
uv.lock ADDED
The diff for this file is too large to render. See raw diff