derek tingle committed
Commit 6062b47 · 1 Parent(s): 56bbe8e

Initial commit
Files changed (5)
  1. README.md +52 -13
  2. app.py +1043 -0
  3. fibo_edit_pipeline.py +953 -0
  4. requirements.txt +133 -0
  5. utils.py +113 -0
README.md CHANGED
@@ -1,13 +1,52 @@
- ---
- title: Fibo Edit Camera Angle
- emoji: 📈
- colorFrom: blue
- colorTo: pink
- sdk: gradio
- sdk_version: 6.4.0
- app_file: app.py
- pinned: false
- short_description: Camera Angle Control using Fibo Edit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Fibo Edit — Camera Angle Control
+
+ Fibo Edit with Multi-Angle LoRA for precise camera control. Control rotation, tilt, and zoom to generate images from any angle.
+
+ ## Features
+
+ - 🎬 Interactive 3D camera control widget
+ - 🎨 Multi-angle image generation using the Fibo Edit model
+ - 📐 Precise control over rotation, tilt, and zoom
+ - 🤖 BRIA API integration for structured captions
+ - ⚡ GPU-accelerated inference with Spaces GPU
+
+ ## Setup
+
+ ### Required Secrets
+
+ This Space requires the following environment variable to be set as a **HuggingFace Space Secret**:
+
+ - `BRIA_API_TOKEN` - Your BRIA API token for structured caption generation
+
+ To add this secret:
+ 1. Go to your Space's Settings
+ 2. Navigate to "Repository secrets"
+ 3. Add a new secret named `BRIA_API_TOKEN` with your API token value
+
+ ### Hardware Requirements
+
+ This Space requires a GPU. Make sure to configure your Space to use a GPU instance.
+
+ ## Usage
+
+ 1. Upload an input image
+ 2. Use the 3D camera control or sliders to adjust:
+    - **Rotation**: -180° to +180° (0° = front, ±180° = back)
+    - **Vertical Tilt**: -1 (low angle) to +1 (high angle)
+    - **Zoom**: 0 (wide) to 10 (close-up)
+ 3. Click "Generate" to create the image from the new camera angle
+ 4. View the structured caption from the BRIA API in the accordion
+
+ ## Model Information
+
+ - **Base Model**: [briaai/FIBO-Edit](https://huggingface.co/briaai/FIBO-Edit)
+ - **LoRA**: [briaai/fibo_edit_multi_angle_full_0121_full_1k](https://huggingface.co/briaai/fibo_edit_multi_angle_full_0121_full_1k)
+ - **Text Encoder**: SmolLM3
+ - **Scheduler**: FlowMatchEulerDiscreteScheduler
+
+ ## Credits
+
+ Built with:
+ - [Gradio](https://gradio.app/)
+ - [Diffusers](https://huggingface.co/docs/diffusers)
+ - [BRIA AI](https://bria.ai/)
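
The rotation, tilt, and zoom controls described above end up as discrete camera phrases in the generation prompt. As a rough illustration of that mapping (the function name and bucket thresholds here are assumptions for the sketch, not the app's exact logic):

```python
# Hypothetical sketch: bucket continuous slider values into camera phrases.
# Thresholds are illustrative assumptions, not the app's exact boundaries.
def describe_camera(rotation: float, tilt: float, zoom: float) -> str:
    parts = []
    if rotation != 0:
        direction = "right" if rotation > 0 else "left"
        parts.append(f"rotate {abs(rotation):g} degrees {direction}")
    if tilt >= 0.5:
        parts.append("high angle")
    elif tilt <= -0.5:
        parts.append("low angle")
    else:
        parts.append("eye level")
    if zoom >= 7:
        parts.append("close-up")
    elif zoom <= 3:
        parts.append("wide shot")
    return ", ".join(parts)

print(describe_camera(90, 0, 8))  # rotate 90 degrees right, eye level, close-up
```

The actual discretization lives in `utils.AngleInstruction` in this commit.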
app.py ADDED
@@ -0,0 +1,1043 @@
+ import base64
+ import json
+ import os
+ import random
+ import time
+ from io import BytesIO
+ from typing import Optional, Tuple
+
+ import gradio as gr
+ import numpy as np
+ import requests
+ import spaces
+ import torch
+ from PIL import Image
+
+ from fibo_edit_pipeline import BriaFiboEditPipeline
+ from utils import AngleInstruction
+
+ # --- Configuration ---
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Run locally or on HuggingFace Spaces
+ RUN_LOCAL = True
+
+ # Model paths
+ BASE_CHECKPOINT = "briaai/FIBO-Edit"  # HuggingFace model ID
+ LORA_CHECKPOINT = "briaai/fibo_edit_multi_angle_full_0121_full_1k"  # HuggingFace LoRA model ID
+
+ # BRIA API configuration
+ BRIA_API_URL = "https://engine.prod.bria-api.com/v2/structured_prompt/generate/pro"
+ BRIA_API_TOKEN = os.environ.get("BRIA_API_TOKEN")
+
+ if not BRIA_API_TOKEN:
+     raise ValueError(
+         "BRIA_API_TOKEN environment variable is not set. "
+         "Please add it as a HuggingFace Space secret."
+     )
+
+ # Generation defaults
+ DEFAULT_NUM_INFERENCE_STEPS = 50
+ DEFAULT_GUIDANCE_SCALE = 3.5
+ DEFAULT_SEED = 100050
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ print("🚀 Starting Fibo Edit Multi-Angle LoRA Gradio App")
+ print(f"Device: {device}")
+ print(f"Base checkpoint: {BASE_CHECKPOINT}")
+ print(f"LoRA checkpoint: {LORA_CHECKPOINT}")
+
+
+ # --- Helper Functions ---
+ def load_pipeline_fiboedit(
+     checkpoint: str,
+     lora_checkpoint: Optional[str] = None,
+     lora_scale: Optional[float] = None,
+     fuse_lora: bool = True,
+ ):
+     """
+     Load the Fibo Edit pipeline using BriaFiboEditPipeline with optional LoRA weights.
+
+     Args:
+         checkpoint: HuggingFace model ID for the base model
+         lora_checkpoint: Optional HuggingFace model ID for LoRA weights
+         lora_scale: Scale for LoRA weights when fusing (default None = 1.0)
+         fuse_lora: Whether to fuse LoRA into the base weights (default True)
+
+     Returns:
+         Loaded BriaFiboEditPipeline
+     """
+     print(f"Loading BriaFiboEditPipeline from {checkpoint}")
+     if lora_checkpoint:
+         print(f"  with LoRA from {lora_checkpoint}")
+
+     # Load pipeline from HuggingFace
+     print("Loading pipeline...")
+     pipe = BriaFiboEditPipeline.from_pretrained(
+         checkpoint,
+         torch_dtype=torch.bfloat16,
+     )
+     pipe.to("cuda")
+     print(f"  Pipeline loaded from {checkpoint}")
+
+     # Load LoRA weights if provided (PEFT format)
+     if lora_checkpoint:
+         print(f"Loading PEFT LoRA from {lora_checkpoint}...")
+         from peft import PeftModel
+
+         print("  Loading PEFT adapter onto transformer...")
+         pipe.transformer = PeftModel.from_pretrained(pipe.transformer, lora_checkpoint)
+         print("  PEFT adapter loaded successfully")
+
+         if fuse_lora:
+             print("  Merging LoRA into base weights...")
+             if hasattr(pipe.transformer, "merge_and_unload"):
+                 pipe.transformer = pipe.transformer.merge_and_unload()
+                 print("  LoRA merged and unloaded")
+             else:
+                 print("  [WARN] transformer.merge_and_unload() not available")
+
+     print("✅ Pipeline loaded successfully!")
+     return pipe
+
+
+ def generate_structured_caption(
+     image: Image.Image, prompt: str, seed: int = 1
+ ) -> Optional[dict]:
+     """Generate a structured caption using the BRIA API."""
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     image_bytes = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+     payload = {
+         "seed": seed,
+         "sync": True,
+         "images": [image_bytes],
+         "prompt": prompt,
+     }
+
+     headers = {
+         "Content-Type": "application/json",
+         "api_token": BRIA_API_TOKEN,
+     }
+
+     max_retries = 3
+     for attempt in range(max_retries):
+         try:
+             response = requests.post(
+                 BRIA_API_URL, json=payload, headers=headers, timeout=60
+             )
+             response.raise_for_status()
+             data = response.json()
+             structured_prompt_str = data["result"]["structured_prompt"]
+             return json.loads(structured_prompt_str)
+         except Exception as e:
+             if attempt == max_retries - 1:
+                 print(f"Failed to generate structured caption: {e}")
+                 return None
+             time.sleep(3)
+
+     return None
+
+
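
`generate_structured_caption` retries the API call up to three times with a fixed 3-second sleep between attempts. The same pattern can be factored into a small generic helper; this is a hedged sketch (`with_retries` is not part of the repo):

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def with_retries(fn: Callable[[], T], max_retries: int = 3, delay: float = 3.0) -> Optional[T]:
    """Call fn(), retrying on any exception; return None after the last failure."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                return None
            time.sleep(delay)
    return None

# A flaky function that fails twice, then succeeds on the third call.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient")
    return "ok"

print(with_retries(flaky, max_retries=3, delay=0))  # ok
```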
+ # --- Model Loading ---
+ print("Loading Fibo Edit pipeline...")
+
+ try:
+     pipe = load_pipeline_fiboedit(
+         checkpoint=BASE_CHECKPOINT,
+         lora_checkpoint=LORA_CHECKPOINT,
+         lora_scale=None,
+         fuse_lora=True,
+     )
+
+     if torch.cuda.is_available():
+         mem_allocated = torch.cuda.memory_allocated(0) / 1024**3
+         print(f"  GPU memory allocated: {mem_allocated:.2f} GB")
+
+ except Exception as e:
+     print(f"❌ Error loading pipeline: {e}")
+     import traceback
+
+     traceback.print_exc()
+     raise
+
+
+ def build_camera_prompt(
+     rotate_deg: float = 0.0, zoom: float = 0.0, vertical_tilt: float = 0.0
+ ) -> Tuple[str, AngleInstruction]:
+     """Build a natural-language camera instruction and AngleInstruction from parameters."""
+     # Create AngleInstruction from camera parameters
+     angle_instruction = AngleInstruction.from_camera_params(
+         rotation=rotate_deg, tilt=vertical_tilt, zoom=zoom
+     )
+
+     # Generate natural language description
+     view_map = {
+         "back view": "view from the opposite side",
+         "back-left quarter view": "rotate 135 degrees left",
+         "back-right quarter view": "rotate 135 degrees right",
+         "front view": "keep the front view",
+         "front-left quarter view": "rotate 45 degrees left",
+         "front-right quarter view": "rotate 45 degrees right",
+         "left side view": "rotate 90 degrees left",
+         "right side view": "rotate 90 degrees right",
+     }
+
+     shot_map = {
+         "elevated shot": "with an elevated viewing angle",
+         "eye-level shot": "with an eye-level viewing angle",
+         "high-angle shot": "with a high-angle viewing angle",
+         "low-angle shot": "with a low-angle viewing angle",
+     }
+
+     zoom_map = {
+         "close-up": "and make it a close-up shot",
+         "medium shot": "",  # Omit medium shot
+         "wide shot": "and make it a wide shot",
+     }
+
+     view_text = view_map[angle_instruction.view.value]
+     shot_text = shot_map[angle_instruction.shot.value]
+     zoom_text = zoom_map[angle_instruction.zoom.value]
+
+     # Construct the natural language prompt starting with "Change the viewing angle"
+     parts = [view_text, shot_text]
+     if zoom_text:  # Only add zoom if not empty (medium shot is omitted)
+         parts.append(zoom_text)
+     natural_prompt = "Change the viewing angle: " + ", ".join(parts)
+
+     return natural_prompt, angle_instruction
+
+
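
`AngleInstruction.from_camera_params` (defined in `utils.py`, not shown in this commit view) presumably discretizes the continuous slider values, and the 3D widget later in this file snaps to the steps `[-180, -135, ..., 180]`, `[0, 5, 10]`, and `[-1, -0.5, 0, 0.5, 1]`. A Python equivalent of that nearest-step snapping, offered as an illustrative sketch:

```python
def snap_to_nearest(value: float, steps: list) -> float:
    """Return the element of `steps` closest to `value` (mirrors the widget's snapToNearest)."""
    return min(steps, key=lambda s: abs(s - value))

ROTATE_STEPS = [-180, -135, -90, -45, 0, 45, 90, 135, 180]
print(snap_to_nearest(50, ROTATE_STEPS))    # 45
print(snap_to_nearest(-100, ROTATE_STEPS))  # -90
```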
+ def fetch_structured_caption(
+     image: Optional[Image.Image] = None,
+     rotate_deg: float = 0.0,
+     zoom: float = 0.0,
+     vertical_tilt: float = 0.0,
+     seed: int = 0,
+     randomize_seed: bool = True,
+     prev_output: Optional[Image.Image] = None,
+ ) -> Tuple[int, str, dict, Image.Image]:
+     """Fetch a structured caption from the BRIA API."""
+
+     # Build natural language prompt and angle instruction
+     natural_prompt, angle_instruction = build_camera_prompt(
+         rotate_deg, zoom, vertical_tilt
+     )
+     print(f"Natural Language Prompt: {natural_prompt}")
+     print(f"Angle Instruction: {str(angle_instruction)}")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Get input image
+     if image is not None:
+         if isinstance(image, Image.Image):
+             input_image = image.convert("RGB")
+         elif hasattr(image, "name"):
+             input_image = Image.open(image.name).convert("RGB")
+         else:
+             input_image = image
+     elif prev_output:
+         input_image = prev_output.convert("RGB")
+     else:
+         raise gr.Error("Please upload an image first.")
+
+     # Generate structured caption using BRIA API
+     print("Generating structured caption from BRIA API...")
+     structured_caption = generate_structured_caption(
+         input_image, natural_prompt, seed=seed
+     )
+
+     if structured_caption is None:
+         raise gr.Error("Failed to generate structured caption from BRIA API")
+
+     # Replace edit_instruction with the angle instruction string
+     structured_caption["edit_instruction"] = str(angle_instruction)
+
+     print(
+         f"Structured caption received: {json.dumps(structured_caption, ensure_ascii=False)}"
+     )
+
+     return seed, natural_prompt, structured_caption, input_image
+
+
+ @spaces.GPU
+ def generate_image_from_caption(
+     input_image: Image.Image,
+     structured_caption: dict,
+     seed: int,
+     guidance_scale: float = 3.5,
+     num_inference_steps: int = 50,
+ ) -> Image.Image:
+     """Generate an image using the Fibo Edit pipeline with a structured caption."""
+
+     structured_prompt = json.dumps(structured_caption, ensure_ascii=False)
+     print("Generating image with structured prompt...")
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     result = pipe(
+         image=input_image,
+         prompt=structured_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         generator=generator,
+         num_images_per_prompt=1,
+     ).images[0]
+
+     return result
+
+
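
Seeding the `torch.Generator` above is what makes generation reproducible: the same seed yields the same latent noise and hence the same output image. The same principle, demonstrated with the standard library's `random.Random` (a sketch of the idea, not the pipeline's actual noise sampling):

```python
import random

def sample_noise(seed: int, n: int = 4) -> list:
    """Deterministic pseudo-noise: the same seed always yields the same values."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]

a = sample_noise(100050)
b = sample_noise(100050)
c = sample_noise(100051)
print(a == b)  # True
print(a == c)  # False
```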
+ # --- 3D Camera Control Component ---
+ # Using gr.HTML directly with templates (Gradio 6 style)
+
+ CAMERA_3D_HTML_TEMPLATE = """
+ <div id="camera-control-wrapper" style="width: 100%; height: 400px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
+   <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 11px; color: #00ff88; white-space: nowrap; z-index: 10; max-width: 90%; overflow: hidden; text-overflow: ellipsis;"></div>
+   <div id="control-legend" style="position: absolute; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 8px 12px; border-radius: 8px; font-family: system-ui; font-size: 11px; color: #fff; z-index: 10;">
+     <div style="margin-bottom: 4px;"><span style="color: #00ff88;">●</span> Rotation (↔)</div>
+     <div style="margin-bottom: 4px;"><span style="color: #ff69b4;">●</span> Vertical Tilt (↕)</div>
+     <div><span style="color: #ffa500;">●</span> Distance/Zoom</div>
+   </div>
+ </div>
+ """
+
308
+ CAMERA_3D_JS = """
309
+ (() => {
310
+ const wrapper = element.querySelector('#camera-control-wrapper');
311
+ const promptOverlay = element.querySelector('#prompt-overlay');
312
+
313
+ const initScene = () => {
314
+ if (typeof THREE === 'undefined') {
315
+ setTimeout(initScene, 100);
316
+ return;
317
+ }
318
+
319
+ const scene = new THREE.Scene();
320
+ scene.background = new THREE.Color(0x1a1a1a);
321
+
322
+ const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
323
+ camera.position.set(4, 3, 4);
324
+ camera.lookAt(0, 0.75, 0);
325
+
326
+ const renderer = new THREE.WebGLRenderer({ antialias: true });
327
+ renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
328
+ renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
329
+ wrapper.insertBefore(renderer.domElement, wrapper.firstChild);
330
+
331
+ scene.add(new THREE.AmbientLight(0xffffff, 0.6));
332
+ const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
333
+ dirLight.position.set(5, 10, 5);
334
+ scene.add(dirLight);
335
+
336
+ scene.add(new THREE.GridHelper(6, 12, 0x333333, 0x222222));
337
+
338
+ const CENTER = new THREE.Vector3(0, 0.75, 0);
339
+ const BASE_DISTANCE = 2.0;
340
+ const ROTATION_RADIUS = 2.2;
341
+ const TILT_RADIUS = 1.6;
342
+
343
+ let rotateDeg = props.value?.rotate_deg || 0;
344
+ let zoom = props.value?.zoom || 5.0;
345
+ let verticalTilt = props.value?.vertical_tilt || 0;
346
+
347
+ const rotateSteps = [-180, -135, -90, -45, 0, 45, 90, 135, 180];
348
+ const zoomSteps = [0, 5, 10];
349
+ const tiltSteps = [-1, -0.5, 0, 0.5, 1];
350
+
351
+ function snapToNearest(value, steps) {
352
+ return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
353
+ }
354
+
355
+ function createPlaceholderTexture() {
356
+ const canvas = document.createElement('canvas');
357
+ canvas.width = 256;
358
+ canvas.height = 256;
359
+ const ctx = canvas.getContext('2d');
360
+ ctx.fillStyle = '#3a3a4a';
361
+ ctx.fillRect(0, 0, 256, 256);
362
+ ctx.fillStyle = '#ffcc99';
363
+ ctx.beginPath();
364
+ ctx.arc(128, 128, 80, 0, Math.PI * 2);
365
+ ctx.fill();
366
+ ctx.fillStyle = '#333';
367
+ ctx.beginPath();
368
+ ctx.arc(100, 110, 10, 0, Math.PI * 2);
369
+ ctx.arc(156, 110, 10, 0, Math.PI * 2);
370
+ ctx.fill();
371
+ ctx.strokeStyle = '#333';
372
+ ctx.lineWidth = 3;
373
+ ctx.beginPath();
374
+ ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
375
+ ctx.stroke();
376
+ return new THREE.CanvasTexture(canvas);
377
+ }
378
+
379
+ let currentTexture = createPlaceholderTexture();
380
+ const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
381
+ let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
382
+ targetPlane.position.copy(CENTER);
383
+ scene.add(targetPlane);
384
+
385
+ function updateTextureFromUrl(url) {
386
+ if (!url) {
387
+ planeMaterial.map = createPlaceholderTexture();
388
+ planeMaterial.needsUpdate = true;
389
+ scene.remove(targetPlane);
390
+ targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
391
+ targetPlane.position.copy(CENTER);
392
+ scene.add(targetPlane);
393
+ return;
394
+ }
395
+
396
+ const loader = new THREE.TextureLoader();
397
+ loader.crossOrigin = 'anonymous';
398
+ loader.load(url, (texture) => {
399
+ texture.minFilter = THREE.LinearFilter;
400
+ texture.magFilter = THREE.LinearFilter;
401
+ planeMaterial.map = texture;
402
+ planeMaterial.needsUpdate = true;
403
+
404
+ const img = texture.image;
405
+ if (img && img.width && img.height) {
406
+ const aspect = img.width / img.height;
407
+ const maxSize = 1.4;
408
+ let planeWidth, planeHeight;
409
+ if (aspect > 1) {
410
+ planeWidth = maxSize;
411
+ planeHeight = maxSize / aspect;
412
+ } else {
413
+ planeHeight = maxSize;
414
+ planeWidth = maxSize * aspect;
415
+ }
416
+ scene.remove(targetPlane);
417
+ targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(planeWidth, planeHeight), planeMaterial);
418
+ targetPlane.position.copy(CENTER);
419
+ scene.add(targetPlane);
420
+ }
421
+ });
422
+ }
423
+
424
+ if (props.imageUrl) {
425
+ updateTextureFromUrl(props.imageUrl);
426
+ }
427
+
428
+ const cameraGroup = new THREE.Group();
429
+ const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
430
+ const body = new THREE.Mesh(new THREE.BoxGeometry(0.28, 0.2, 0.35), bodyMat);
431
+ cameraGroup.add(body);
432
+ const lens = new THREE.Mesh(
433
+ new THREE.CylinderGeometry(0.08, 0.1, 0.16, 16),
434
+ new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
435
+ );
436
+ lens.rotation.x = Math.PI / 2;
437
+ lens.position.z = 0.24;
438
+ cameraGroup.add(lens);
439
+ scene.add(cameraGroup);
440
+
441
+ const rotationArcPoints = [];
442
+ for (let i = 0; i <= 64; i++) {
443
+ const angle = THREE.MathUtils.degToRad((360 * i / 64));
444
+ rotationArcPoints.push(new THREE.Vector3(ROTATION_RADIUS * Math.sin(angle), 0.05, ROTATION_RADIUS * Math.cos(angle)));
445
+ }
446
+ const rotationCurve = new THREE.CatmullRomCurve3(rotationArcPoints);
447
+ const rotationArc = new THREE.Mesh(
448
+ new THREE.TubeGeometry(rotationCurve, 64, 0.035, 8, true),
449
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
450
+ );
451
+ scene.add(rotationArc);
452
+
453
+ const rotationHandle = new THREE.Mesh(
454
+ new THREE.SphereGeometry(0.16, 16, 16),
455
+ new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
456
+ );
457
+ rotationHandle.userData.type = 'rotation';
458
+ scene.add(rotationHandle);
459
+
460
+ const tiltArcPoints = [];
461
+ for (let i = 0; i <= 32; i++) {
462
+ const angle = THREE.MathUtils.degToRad(-45 + (90 * i / 32));
463
+ tiltArcPoints.push(new THREE.Vector3(-0.7, TILT_RADIUS * Math.sin(angle) + CENTER.y, TILT_RADIUS * Math.cos(angle)));
464
+ }
465
+ const tiltCurve = new THREE.CatmullRomCurve3(tiltArcPoints);
466
+ const tiltArc = new THREE.Mesh(
467
+ new THREE.TubeGeometry(tiltCurve, 32, 0.035, 8, false),
468
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
469
+ );
470
+ scene.add(tiltArc);
471
+
472
+ const tiltHandle = new THREE.Mesh(
473
+ new THREE.SphereGeometry(0.16, 16, 16),
474
+ new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
475
+ );
476
+ tiltHandle.userData.type = 'tilt';
477
+ scene.add(tiltHandle);
478
+
479
+ const distanceLineGeo = new THREE.BufferGeometry();
480
+ const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
481
+ scene.add(distanceLine);
482
+
483
+ const distanceHandle = new THREE.Mesh(
484
+ new THREE.SphereGeometry(0.16, 16, 16),
485
+ new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
486
+ );
487
+ distanceHandle.userData.type = 'distance';
488
+ scene.add(distanceHandle);
489
+
490
+ function buildPromptText(rot, zoomVal, tilt) {
491
+ const parts = [];
492
+ if (rot !== 0) {
493
+ const dir = rot > 0 ? 'right' : 'left';
494
+ parts.push('Rotate ' + Math.abs(rot) + '° ' + dir);
495
+ }
496
+ if (zoomVal >= 6.66) parts.push('Close-up');
497
+ else if (zoomVal >= 3.33) parts.push('Medium shot');
498
+ else parts.push('Wide angle');
499
+ if (tilt >= 0.66) parts.push("High angle");
500
+ else if (tilt >= 0.33) parts.push("Elevated");
501
+ else if (tilt <= -0.33) parts.push("Low angle");
502
+ else parts.push("Eye level");
503
+ return parts.length > 0 ? parts.join(' • ') : 'No camera movement';
504
+ }
505
+
506
+ function updatePositions() {
507
+ const rotRad = THREE.MathUtils.degToRad(rotateDeg);
508
+ // Map zoom 0-10 to distance: zoom 0 = far (3.0), zoom 10 = close (1.0)
509
+ const distance = 3.0 - (zoom / 10) * 2.0;
510
+ const tiltAngle = verticalTilt * 35;
511
+ const tiltRad = THREE.MathUtils.degToRad(tiltAngle);
512
+
513
+ const camX = distance * Math.sin(rotRad) * Math.cos(tiltRad);
514
+ const camY = distance * Math.sin(tiltRad) + CENTER.y;
515
+ const camZ = distance * Math.cos(rotRad) * Math.cos(tiltRad);
516
+
517
+ cameraGroup.position.set(camX, camY, camZ);
518
+ cameraGroup.lookAt(CENTER);
519
+
520
+ rotationHandle.position.set(ROTATION_RADIUS * Math.sin(rotRad), 0.05, ROTATION_RADIUS * Math.cos(rotRad));
521
+
522
+ const tiltHandleAngle = THREE.MathUtils.degToRad(tiltAngle);
523
+ tiltHandle.position.set(-0.7, TILT_RADIUS * Math.sin(tiltHandleAngle) + CENTER.y, TILT_RADIUS * Math.cos(tiltHandleAngle));
524
+
525
+ const handleDist = distance - 0.4;
526
+ distanceHandle.position.set(
527
+ handleDist * Math.sin(rotRad) * Math.cos(tiltRad),
528
+ handleDist * Math.sin(tiltRad) + CENTER.y,
529
+ handleDist * Math.cos(rotRad) * Math.cos(tiltRad)
530
+ );
531
+ distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);
532
+
533
+ promptOverlay.textContent = buildPromptText(rotateDeg, zoom, verticalTilt);
534
+ }
535
+
536
+ function updatePropsAndTrigger() {
537
+ const rotSnap = snapToNearest(rotateDeg, rotateSteps);
538
+ const zoomSnap = snapToNearest(zoom, zoomSteps);
539
+ const tiltSnap = snapToNearest(verticalTilt, tiltSteps);
540
+
541
+ props.value = { rotate_deg: rotSnap, zoom: zoomSnap, vertical_tilt: tiltSnap };
542
+ trigger('change', props.value);
543
+ }
544
+
545
+ const raycaster = new THREE.Raycaster();
546
+ const mouse = new THREE.Vector2();
547
+ let isDragging = false;
548
+ let dragTarget = null;
549
+ let dragStartMouse = new THREE.Vector2();
550
+ let dragStartZoom = 0;
551
+ const intersection = new THREE.Vector3();
552
+
553
+ const canvas = renderer.domElement;
554
+
555
+ canvas.addEventListener('mousedown', (e) => {
556
+ const rect = canvas.getBoundingClientRect();
557
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
558
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
559
+
560
+ raycaster.setFromCamera(mouse, camera);
561
+ const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);
562
+
563
+ if (intersects.length > 0) {
564
+ isDragging = true;
565
+ dragTarget = intersects[0].object;
566
+ dragTarget.material.emissiveIntensity = 1.0;
567
+ dragTarget.scale.setScalar(1.3);
568
+ dragStartMouse.copy(mouse);
569
+ dragStartZoom = zoom;
570
+ canvas.style.cursor = 'grabbing';
571
+ }
572
+ });
573
+
574
+ canvas.addEventListener('mousemove', (e) => {
575
+ const rect = canvas.getBoundingClientRect();
576
+ mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
577
+ mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
578
+
579
+ if (isDragging && dragTarget) {
580
+ raycaster.setFromCamera(mouse, camera);
581
+
582
+ if (dragTarget.userData.type === 'rotation') {
583
+ const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
584
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
585
+ let angle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
586
+ rotateDeg = THREE.MathUtils.clamp(angle, -180, 180);
587
+ }
588
+ } else if (dragTarget.userData.type === 'tilt') {
589
+ const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), 0.7);
590
+ if (raycaster.ray.intersectPlane(plane, intersection)) {
591
+ const relY = intersection.y - CENTER.y;
592
+ const relZ = intersection.z;
593
+ const angle = THREE.MathUtils.radToDeg(Math.atan2(relY, relZ));
594
+ verticalTilt = THREE.MathUtils.clamp(angle / 35, -1, 1);
595
+ }
596
+ } else if (dragTarget.userData.type === 'distance') {
597
+ const deltaY = mouse.y - dragStartMouse.y;
598
+ zoom = THREE.MathUtils.clamp(dragStartZoom + deltaY * 20, 0, 10);
599
+ }
600
+ updatePositions();
601
+ } else {
602
+ raycaster.setFromCamera(mouse, camera);
603
+ const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);
604
+ [rotationHandle, tiltHandle, distanceHandle].forEach(h => {
605
+ h.material.emissiveIntensity = 0.5;
606
+ h.scale.setScalar(1);
607
+ });
608
+ if (intersects.length > 0) {
609
+ intersects[0].object.material.emissiveIntensity = 0.8;
610
+ intersects[0].object.scale.setScalar(1.1);
611
+ canvas.style.cursor = 'grab';
612
+ } else {
613
+ canvas.style.cursor = 'default';
614
+ }
615
+ }
616
+ });
617
+
618
+ const onMouseUp = () => {
619
+ if (dragTarget) {
620
+ dragTarget.material.emissiveIntensity = 0.5;
621
+ dragTarget.scale.setScalar(1);
622
+
623
+ const targetRot = snapToNearest(rotateDeg, rotateSteps);
624
+ const targetZoom = snapToNearest(zoom, zoomSteps);
625
+ const targetTilt = snapToNearest(verticalTilt, tiltSteps);
626
+
627
+ const startRot = rotateDeg, startZoom = zoom, startTilt = verticalTilt;
628
+ const startTime = Date.now();
629
+
630
+ function animateSnap() {
631
+ const t = Math.min((Date.now() - startTime) / 200, 1);
632
+ const ease = 1 - Math.pow(1 - t, 3);
633
+
634
+ rotateDeg = startRot + (targetRot - startRot) * ease;
635
+ zoom = startZoom + (targetZoom - startZoom) * ease;
636
+ verticalTilt = startTilt + (targetTilt - startTilt) * ease;
637
+
638
+ updatePositions();
639
+ if (t < 1) requestAnimationFrame(animateSnap);
640
+ else updatePropsAndTrigger();
641
+ }
642
+ animateSnap();
643
+ }
644
+ isDragging = false;
645
+ dragTarget = null;
646
+ canvas.style.cursor = 'default';
647
+ };
648
+
649
+ canvas.addEventListener('mouseup', onMouseUp);
650
+ canvas.addEventListener('mouseleave', onMouseUp);
651
+
652
+ canvas.addEventListener('touchstart', (e) => {
653
+ e.preventDefault();
654
+ const touch = e.touches[0];
655
+ const rect = canvas.getBoundingClientRect();
656
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
657
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
658
+
659
+ raycaster.setFromCamera(mouse, camera);
660
+ const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);
661
+
662
+ if (intersects.length > 0) {
663
+ isDragging = true;
664
+ dragTarget = intersects[0].object;
665
+ dragTarget.material.emissiveIntensity = 1.0;
666
+ dragTarget.scale.setScalar(1.3);
667
+ dragStartMouse.copy(mouse);
668
+ dragStartZoom = zoom;
669
+ }
670
+ }, { passive: false });
671
+
672
+ canvas.addEventListener('touchmove', (e) => {
673
+ e.preventDefault();
674
+ const touch = e.touches[0];
675
+ const rect = canvas.getBoundingClientRect();
676
+ mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
677
+ mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
678
+
679
        if (isDragging && dragTarget) {
            raycaster.setFromCamera(mouse, camera);

            if (dragTarget.userData.type === 'rotation') {
                const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
                if (raycaster.ray.intersectPlane(plane, intersection)) {
                    let angle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
                    rotateDeg = THREE.MathUtils.clamp(angle, -180, 180);
                }
            } else if (dragTarget.userData.type === 'tilt') {
                const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), 0.7);
                if (raycaster.ray.intersectPlane(plane, intersection)) {
                    const relY = intersection.y - CENTER.y;
                    const relZ = intersection.z;
                    const angle = THREE.MathUtils.radToDeg(Math.atan2(relY, relZ));
                    verticalTilt = THREE.MathUtils.clamp(angle / 35, -1, 1);
                }
            } else if (dragTarget.userData.type === 'distance') {
                const deltaY = mouse.y - dragStartMouse.y;
                zoom = THREE.MathUtils.clamp(dragStartZoom + deltaY * 20, 0, 10);
            }
            updatePositions();
        }
    }, { passive: false });

    canvas.addEventListener('touchend', (e) => { e.preventDefault(); onMouseUp(); }, { passive: false });
    canvas.addEventListener('touchcancel', (e) => { e.preventDefault(); onMouseUp(); }, { passive: false });

    updatePositions();

    function render() {
        requestAnimationFrame(render);
        renderer.render(scene, camera);
    }
    render();

    new ResizeObserver(() => {
        camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
        camera.updateProjectionMatrix();
        renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
    }).observe(wrapper);

    wrapper._updateTexture = updateTextureFromUrl;

    let lastImageUrl = props.imageUrl;
    let lastValue = JSON.stringify(props.value);
    setInterval(() => {
        if (props.imageUrl !== lastImageUrl) {
            lastImageUrl = props.imageUrl;
            updateTextureFromUrl(props.imageUrl);
        }
        const currentValue = JSON.stringify(props.value);
        if (currentValue !== lastValue) {
            lastValue = currentValue;
            if (props.value && typeof props.value === 'object') {
                rotateDeg = props.value.rotate_deg ?? rotateDeg;
                zoom = props.value.zoom ?? zoom;
                verticalTilt = props.value.vertical_tilt ?? verticalTilt;
                updatePositions();
            }
        }
    }, 100);
};

initScene();
})();
"""

def create_camera_3d_component(value=None, imageUrl=None, **kwargs):
    """Create a 3D camera control component using gr.HTML."""
    if value is None:
        value = {"rotate_deg": 0, "zoom": 5.0, "vertical_tilt": 0}

    return gr.HTML(
        value=value,
        html_template=CAMERA_3D_HTML_TEMPLATE,
        js_on_load=CAMERA_3D_JS,
        imageUrl=imageUrl,
        **kwargs,
    )


# --- UI ---
css = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .progress-text { color: white !important; }
#camera-3d-control { min-height: 400px; }
#examples { max-width: 1100px; margin: 0 auto; }
.fillable { max-width: 1250px !important; }
"""


def reset_all() -> list:
    """Reset all camera control knobs and flags to their default values."""
    return [0, 5.0, 0, True]  # rotate_deg, zoom, vertical_tilt, is_reset


def end_reset() -> bool:
    """Mark the end of a reset cycle."""
    return False


def update_dimensions_on_upload(image: Optional[Image.Image]) -> Tuple[int, int]:
    """Compute recommended (width, height) for the output resolution."""
    if image is None:
        return 1024, 1024

    original_width, original_height = image.size

    if original_width > original_height:
        new_width = 1024
        aspect_ratio = original_height / original_width
        new_height = int(new_width * aspect_ratio)
    else:
        new_height = 1024
        aspect_ratio = original_width / original_height
        new_width = int(new_height * aspect_ratio)

    # Snap both sides down to multiples of 8.
    new_width = (new_width // 8) * 8
    new_height = (new_height // 8) * 8

    return new_width, new_height

with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
    gr.Markdown("""
    ## 🎬 Fibo Edit — Camera Angle Control

    Fibo Edit with Multi-Angle LoRA for precise camera control ✨
    Control rotation, tilt, and zoom to generate images from any angle 🎥
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(label="Input Image", type="pil", height=280)
            prev_output = gr.Image(value=None, visible=False)
            is_reset = gr.Checkbox(value=False, visible=False)
            # Hidden state to pass the processed image between stages
            processed_image = gr.State(None)

            gr.Markdown("### 🎮 3D Camera Control")

            camera_3d = create_camera_3d_component(
                value={"rotate_deg": 0, "zoom": 5.0, "vertical_tilt": 0},
                elem_id="camera-3d-control",
            )

            with gr.Row():
                reset_btn = gr.Button("🔄 Reset", size="sm")
                run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            result = gr.Image(label="Output Image", interactive=False, height=350)

            gr.Markdown("### 🎚️ Slider Controls")

            rotate_deg = gr.Slider(
                label="Horizontal Rotation (°)",
                minimum=-180,
                maximum=180,
                step=45,
                value=0,
                info="-180/180: back, -90: left, 0: front, 90: right",
            )
            zoom = gr.Slider(
                label="Zoom Level",
                minimum=0,
                maximum=10,
                step=1,
                value=5.0,
                info="0-3.33: wide, 3.33-6.66: medium, 6.66-10: close-up",
            )
            vertical_tilt = gr.Slider(
                label="Vertical Tilt",
                minimum=-1,
                maximum=1,
                step=0.5,
                value=0,
                info="-1: low-angle, 0: eye-level, 1: high-angle",
            )

            prompt_preview = gr.Textbox(label="Generated Prompt", interactive=False)

            with gr.Accordion("📋 Structured Caption (BRIA API)", open=False):
                structured_json = gr.JSON(label="JSON Response", container=False)

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=DEFAULT_SEED,
                )
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULT_GUIDANCE_SCALE,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=DEFAULT_NUM_INFERENCE_STEPS,
                )
                height = gr.Slider(
                    label="Height", minimum=256, maximum=2048, step=8, value=1024
                )
                width = gr.Slider(
                    label="Width", minimum=256, maximum=2048, step=8, value=1024
                )

    # --- Helper Functions ---
    def update_prompt_from_sliders(rotate, zoom_val, tilt):
        prompt, _ = build_camera_prompt(rotate, zoom_val, tilt)
        return prompt

    def sync_3d_to_sliders(camera_value):
        if camera_value and isinstance(camera_value, dict):
            rot = camera_value.get("rotate_deg", 0)
            zoom_val = camera_value.get("zoom", 5.0)
            tilt = camera_value.get("vertical_tilt", 0)
            prompt, _ = build_camera_prompt(rot, zoom_val, tilt)
            return rot, zoom_val, tilt, prompt
        return gr.update(), gr.update(), gr.update(), gr.update()

    def sync_sliders_to_3d(rotate, zoom_val, tilt):
        return {"rotate_deg": rotate, "zoom": zoom_val, "vertical_tilt": tilt}

    def update_3d_image(img):
        if img is None:
            return gr.update(imageUrl=None)
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        data_url = f"data:image/png;base64,{img_str}"
        return gr.update(imageUrl=data_url)

    # --- Event Handlers ---

    # Slider -> Prompt preview
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.change(
            fn=update_prompt_from_sliders,
            inputs=[rotate_deg, zoom, vertical_tilt],
            outputs=[prompt_preview],
        )

    # 3D control -> Sliders + Prompt (no auto-inference)
    camera_3d.change(
        fn=sync_3d_to_sliders,
        inputs=[camera_3d],
        outputs=[rotate_deg, zoom, vertical_tilt, prompt_preview],
    )

    # Sliders -> 3D control (no auto-inference)
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.release(
            fn=sync_sliders_to_3d,
            inputs=[rotate_deg, zoom, vertical_tilt],
            outputs=[camera_3d],
        )

    # Reset
    reset_btn.click(
        fn=reset_all,
        inputs=None,
        outputs=[rotate_deg, zoom, vertical_tilt, is_reset],
        queue=False,
    ).then(fn=end_reset, inputs=None, outputs=[is_reset], queue=False).then(
        fn=sync_sliders_to_3d,
        inputs=[rotate_deg, zoom, vertical_tilt],
        outputs=[camera_3d],
    )

    # Generate button - two-stage process
    # Stage 1: fetch the structured caption from the BRIA API and display it immediately
    run_event = run_btn.click(
        fn=fetch_structured_caption,
        inputs=[
            image,
            rotate_deg,
            zoom,
            vertical_tilt,
            seed,
            randomize_seed,
            prev_output,
        ],
        outputs=[seed, prompt_preview, structured_json, processed_image],
    ).then(
        # Stage 2: generate the image with the Fibo Edit pipeline
        fn=generate_image_from_caption,
        inputs=[
            processed_image,
            structured_json,
            seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result],
    )

    # Image upload
    image.upload(
        fn=update_dimensions_on_upload, inputs=[image], outputs=[width, height]
    ).then(
        fn=reset_all,
        inputs=None,
        outputs=[rotate_deg, zoom, vertical_tilt, is_reset],
        queue=False,
    ).then(fn=end_reset, inputs=None, outputs=[is_reset], queue=False).then(
        fn=update_3d_image, inputs=[image], outputs=[camera_3d]
    )

    image.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])

    run_event.then(lambda img: img, inputs=[result], outputs=[prev_output])

    # Examples are omitted for now since actual example images are needed.
    # Note: with the two-stage inference process, examples would need custom handling
    # to chain fetch_structured_caption -> generate_image_from_caption.

    # Sync the 3D component when sliders change (covers example loading)
    def sync_3d_on_slider_change(img, rot, zoom_val, tilt):
        camera_value = {"rotate_deg": rot, "zoom": zoom_val, "vertical_tilt": tilt}
        if img is not None:
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            data_url = f"data:image/png;base64,{img_str}"
            return gr.update(value=camera_value, imageUrl=data_url)
        return gr.update(value=camera_value)

    # When any slider value changes (including from examples), sync the 3D component
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.change(
            fn=sync_3d_on_slider_change,
            inputs=[image, rotate_deg, zoom, vertical_tilt],
            outputs=[camera_3d],
        )

    # API endpoints for the two-stage inference process
    gr.api(fetch_structured_caption, api_name="fetch_caption")
    gr.api(generate_image_from_caption, api_name="generate_image")

if __name__ == "__main__":
    head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'

    if RUN_LOCAL:
        # Local development configuration
        demo.launch(
            mcp_server=True,
            head=head,
            footer_links=["api", "gradio", "settings"],
            server_name="0.0.0.0",
            server_port=8081,
        )
    else:
        # HuggingFace Spaces standard configuration
        demo.launch(head=head)
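As a quick sanity check of the resizing logic in `update_dimensions_on_upload`, here is a dependency-free sketch of the same arithmetic (the helper name `recommended_size` is ours, not part of the app): fit the longest side to 1024 px, then snap both sides down to multiples of 8, matching the step size of the Height/Width sliders.

```python
def recommended_size(original_width: int, original_height: int) -> tuple:
    """Mirror of update_dimensions_on_upload's math, without PIL."""
    if original_width > original_height:
        new_width = 1024
        new_height = int(new_width * original_height / original_width)
    else:
        new_height = 1024
        new_width = int(new_height * original_width / original_height)
    # Snap both sides down to multiples of 8 (latent-space alignment).
    return (new_width // 8) * 8, (new_height // 8) * 8

print(recommended_size(1920, 1080))  # -> (1024, 576)
```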
fibo_edit_pipeline.py
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.

from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from diffusers.pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Example:
        ```python
        import torch
        from diffusers import BriaFiboPipeline
        from diffusers.modular_pipelines import ModularPipeline

        torch.set_grad_enabled(False)
        vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

        pipe = BriaFiboPipeline.from_pretrained(
            "briaai/FIBO",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        pipe.enable_model_cpu_offload()

        with torch.inference_mode():
            # 1. Create a prompt to generate an initial image
            output = vlm_pipe(prompt="a beautiful dog")
            json_prompt_generate = output.values["json_prompt"]

            # 2. Generate the image from the structured JSON prompt
            results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
            results_generate.images[0].save("image_generate.png")
        ```
"""

PREFERRED_RESOLUTION = {
    256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)],
    512 * 512: [
        (416, 624),
        (432, 592),
        (464, 560),
        (512, 512),
        (544, 480),
        (576, 448),
        (592, 432),
        (608, 416),
        (624, 416),
        (640, 400),
        (672, 384),
        (704, 368),
    ],
    1024 * 1024: [
        (832, 1248),
        (880, 1184),
        (912, 1136),
        (1024, 1024),
        (1136, 912),
        (1184, 880),
        (1216, 848),
        (1248, 832),
        (1264, 816),
        (1296, 800),
        (1360, 768),
    ],
}

class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
    r"""
    Args:
        transformer (`BriaFiboTransformer2DModel`):
            The transformer model for 2D diffusion modeling.
        scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
            Scheduler to be used with `transformer` to denoise the encoded latents.
        vae (`AutoencoderKLWan`):
            Variational Auto-Encoder for encoding and decoding images to and from latent representations.
        text_encoder (`SmolLM3ForCausalLM`):
            Text encoder for processing input prompts.
        tokenizer (`AutoTokenizer`):
            Tokenizer used for processing the input text prompts for the text_encoder.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        transformer: BriaFiboTransformer2DModel,
        scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
        vae: AutoencoderKLWan,
        text_encoder: SmolLM3ForCausalLM,
        tokenizer: AutoTokenizer,
    ):
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

        self.vae_scale_factor = 16
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 32

    def get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 2048,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if not prompt:
            raise ValueError("`prompt` must be a non-empty string or list of strings.")

        batch_size = len(prompt)
        bot_token_id = 128000

        text_encoder_device = device if device is not None else torch.device("cpu")
        if not isinstance(text_encoder_device, torch.device):
            text_encoder_device = torch.device(text_encoder_device)

        if all(p == "" for p in prompt):
            input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
            attention_mask = torch.ones_like(input_ids)
        else:
            tokenized = self.tokenizer(
                prompt,
                padding="longest",
                max_length=max_sequence_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            input_ids = tokenized.input_ids.to(text_encoder_device)
            attention_mask = tokenized.attention_mask.to(text_encoder_device)

            if any(p == "" for p in prompt):
                empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
                input_ids[empty_rows] = bot_token_id
                attention_mask[empty_rows] = 1

        encoder_outputs = self.text_encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_outputs.hidden_states

        prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
        prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hidden_states = tuple(
            layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
        )
        attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)

        return prompt_embeds, hidden_states, attention_mask

    @staticmethod
    def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
        # Pad embeddings to `max_tokens` while preserving the mask of real tokens.
        batch_size, seq_len, dim = prompt_embeds.shape

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
        else:
            attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)

        if max_tokens < seq_len:
            raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")

        if max_tokens > seq_len:
            pad_length = max_tokens - seq_len
            padding = torch.zeros((batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
            prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)

            mask_padding = torch.zeros((batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
            attention_mask = torch.cat([attention_mask, mask_padding], dim=1)

        return prompt_embeds, attention_mask

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 3000,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Args:
            prompt (`str` or `List[str]`, *optional*):
                Prompt to be encoded.
            device (`torch.device`):
                Torch device.
            num_images_per_prompt (`int`):
                Number of images that should be generated per prompt.
            guidance_scale (`float`):
                Guidance scale for classifier-free guidance.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
        """
        device = device or self._execution_device

        # Set the LoRA scale so that the monkey-patched LoRA
        # function of the text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # Dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        prompt_attention_mask = None
        negative_prompt_attention_mask = None
        if prompt_embeds is None:
            prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
            prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]

        if guidance_scale > 1:
            if isinstance(negative_prompt, list) and negative_prompt[0] is None:
                negative_prompt = ""
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
            negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        # Pad to the longest sequence
        if prompt_attention_mask is not None:
            prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)

        if negative_prompt_embeds is not None:
            if negative_prompt_attention_mask is not None:
                negative_prompt_attention_mask = negative_prompt_attention_mask.to(
                    device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
                )
            max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])

            prompt_embeds, prompt_attention_mask = self.pad_embedding(
                prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
            )
            prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]

            negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
                negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
            )
            negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
        else:
            max_tokens = prompt_embeds.shape[1]
            prompt_embeds, prompt_attention_mask = self.pad_embedding(
                prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
            )
            negative_prompt_layers = None

        dtype = self.text_encoder.dtype
        text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_layers,
            negative_prompt_layers,
        )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487. `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @staticmethod
    # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels)
        latents = latents.permute(0, 3, 1, 2)

        return latents

    @staticmethod
    def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
        latents = latents.permute(0, 2, 3, 1)
        latents = latents.reshape(batch_size, height * width, num_channels_latents)
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
        do_patching=False,
    ):
        height = int(height) // self.vae_scale_factor
        width = int(width) // self.vae_scale_factor

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if do_patching:
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
        else:
            latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids

    @staticmethod
    def _prepare_attention_mask(attention_mask):
        attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)

        # Convert to 0 = keep, -inf = ignore
        attention_matrix = torch.where(
            attention_matrix == 1, 0.0, -torch.inf
        )  # Apply -inf to ignored tokens to null their softmax score
        return attention_matrix

+ @torch.no_grad()
487
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
488
+ def __call__(
489
+ self,
490
+ prompt: Union[str, List[str]] = None,
491
+ image: Optional[Union[PIL.Image.Image, torch.FloatTensor]] = None,
492
+ num_inference_steps: int = 30,
493
+ timesteps: List[int] = None,
494
+ guidance_scale: float = 5,
495
+ negative_prompt: Optional[Union[str, List[str]]] = None,
496
+ num_images_per_prompt: Optional[int] = 1,
497
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
498
+ latents: Optional[torch.FloatTensor] = None,
499
+ prompt_embeds: Optional[torch.FloatTensor] = None,
500
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
501
+ output_type: Optional[str] = "pil",
502
+ return_dict: bool = True,
503
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
504
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
505
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
506
+ max_sequence_length: int = 3000,
507
+ do_patching=False,
508
+ _auto_resize: bool = True,
509
+ base_resolution: int = 1024,
510
+ ):
511
+ r"""
512
+ Function invoked when calling the pipeline for generation.
513
+
514
+ Args:
515
+ prompt (`str` or `List[str]`, *optional*):
516
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
517
+ instead.
518
+ image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*):
519
+ The image to guide the image generation. If not defined, the pipeline will generate an image from scratch.
520
+ num_inference_steps (`int`, *optional*, defaults to 30):
521
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
522
+ expense of slower inference.
523
+ timesteps (`List[int]`, *optional*):
524
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
525
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
526
+ passed will be used. Must be in descending order.
527
+ guidance_scale (`float`, *optional*, defaults to 5.0):
528
+ Guidance scale as defined in [Classifier-Free Diffusion
529
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
530
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
531
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
532
+ the text `prompt`, usually at the expense of lower image quality.
533
+ negative_prompt (`str` or `List[str]`, *optional*):
534
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
535
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
536
+ less than `1`).
537
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
538
+ The number of images to generate per prompt.
539
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
540
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
541
+ to make generation deterministic.
542
+ latents (`torch.FloatTensor`, *optional*):
543
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
544
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
545
+ tensor will be generated by sampling with the supplied random `generator`.
546
+ prompt_embeds (`torch.FloatTensor`, *optional*):
547
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
548
+ provided, text embeddings will be generated from `prompt` input argument.
549
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
550
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
551
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
552
+ argument.
553
+ output_type (`str`, *optional*, defaults to `"pil"`):
554
+ The output format of the generated image. Choose between
555
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
556
+ return_dict (`bool`, *optional*, defaults to `True`):
557
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
558
+ of a plain tuple.
559
+ joint_attention_kwargs (`dict`, *optional*):
560
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
561
+ `self.processor` in
562
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
563
+ callback_on_step_end (`Callable`, *optional*):
564
+ A function that is called at the end of each denoising step during inference. The function is called
565
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
566
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
567
+ `callback_on_step_end_tensor_inputs`.
568
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
569
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
570
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
571
+ `._callback_tensor_inputs` attribute of your pipeline class.
572
+ max_sequence_length (`int`, *optional*, defaults to 3000): Maximum sequence length to use with the `prompt`.
573
+ do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
574
+ Examples:
575
+ Returns:
576
+ [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
577
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
578
+ generated images.
579
+ """
580
+
581
+ if image is not None and _auto_resize:
582
+ image_height, image_width = self.image_processor.get_default_height_width(image)
584
+ image_width, image_height = min(
585
+ PREFERRED_RESOLUTION[base_resolution * base_resolution],
586
+ key=lambda size: abs(size[0] / size[1] - image_width / image_height),
587
+ )
588
+ width, height = image_width, image_height
589
+
590
+ # 1. Check inputs. Raise error if not correct
591
+ self.check_inputs( # check flux
592
+ prompt=prompt,
593
+ prompt_embeds=prompt_embeds,
594
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
595
+ max_sequence_length=max_sequence_length,
596
+ )
597
+
598
+ self._guidance_scale = guidance_scale
599
+ self._joint_attention_kwargs = joint_attention_kwargs
600
+ self._interrupt = False
601
+
602
+ # 2. Define call parameters
603
+ if prompt is not None and isinstance(prompt, str):
604
+ batch_size = 1
605
+ elif prompt is not None and isinstance(prompt, list):
606
+ batch_size = len(prompt)
607
+ else:
608
+ batch_size = prompt_embeds.shape[0]
609
+
610
+ device = self._execution_device
611
+
612
+ lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
613
+
614
+ (
615
+ prompt_embeds,
616
+ negative_prompt_embeds,
617
+ text_ids,
618
+ prompt_attention_mask,
619
+ negative_prompt_attention_mask,
620
+ prompt_layers,
621
+ negative_prompt_layers,
622
+ ) = self.encode_prompt(
623
+ prompt=prompt,
624
+ negative_prompt=negative_prompt,
625
+ guidance_scale=guidance_scale,
626
+ prompt_embeds=prompt_embeds,
627
+ negative_prompt_embeds=negative_prompt_embeds,
628
+ device=device,
629
+ max_sequence_length=max_sequence_length,
630
+ num_images_per_prompt=num_images_per_prompt,
631
+ lora_scale=lora_scale,
632
+ )
633
+ prompt_batch_size = prompt_embeds.shape[0]
634
+
635
+ if guidance_scale > 1:
636
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
637
+ prompt_layers = [
638
+ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
639
+ ]
640
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
641
+
642
+ total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
643
+ self.transformer.single_transformer_blocks
644
+ )
645
+ if len(prompt_layers) >= total_num_layers_transformer:
646
+ # remove first layers
647
+ prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
648
+ else:
649
+ # duplicate last layer
650
+ prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
651
+
652
+ # Preprocess image
653
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
654
+ image = self.image_processor.resize(image, height, width)
655
+ image = self.image_processor.preprocess(image, height, width)
656
+
657
+ # 5. Prepare latent variables
658
+ num_channels_latents = self.transformer.config.in_channels
659
+ if do_patching:
660
+ num_channels_latents = int(num_channels_latents / 4)
661
+
662
+ latents, latent_image_ids = self.prepare_latents(
663
+ prompt_batch_size,
664
+ num_channels_latents,
665
+ height,
666
+ width,
667
+ prompt_embeds.dtype,
668
+ device,
669
+ generator,
670
+ latents,
671
+ do_patching,
672
+ )
673
+
674
+ if image is not None:
675
+ image_latents, image_ids = self.prepare_image_latents(
676
+ image=image,
677
+ batch_size=batch_size * num_images_per_prompt,
678
+ num_channels_latents=num_channels_latents,
679
+ height=height,
680
+ width=width,
681
+ dtype=prompt_embeds.dtype,
682
+ device=device,
683
+ generator=generator,
684
+ )
685
+ latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension
686
+ else:
687
+ image_latents = None
688
+
689
+ latent_attention_mask = torch.ones(
690
+ [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
691
+ )
692
+ if guidance_scale > 1:
693
+ latent_attention_mask = latent_attention_mask.repeat(2, 1)
694
+
695
+ if image_latents is None:
696
+ attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
697
+ else:
698
+ image_latent_attention_mask = torch.ones(
699
+ [image_latents.shape[0], image_latents.shape[1]],
700
+ dtype=image_latents.dtype,
701
+ device=image_latents.device,
702
+ )
703
+ if guidance_scale > 1:
704
+ image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1)
705
+ attention_mask = torch.cat(
706
+ [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1
707
+ )
708
+
709
+ attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq
710
+ attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
711
+
712
+ if self._joint_attention_kwargs is None:
713
+ self._joint_attention_kwargs = {}
714
+ self._joint_attention_kwargs["attention_mask"] = attention_mask
715
+
716
+ # Adapt scheduler to dynamic shifting (resolution dependent)
717
+
718
+ if do_patching:
719
+ seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
720
+ else:
721
+ seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
722
+
723
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
724
+
725
+ mu = calculate_shift(
726
+ seq_len,
727
+ self.scheduler.config.base_image_seq_len,
728
+ self.scheduler.config.max_image_seq_len,
729
+ self.scheduler.config.base_shift,
730
+ self.scheduler.config.max_shift,
731
+ )
732
+
733
+ # Init sigmas and timesteps according to shift size
734
+ # This changes the scheduler in-place according to the dynamic scheduling
735
+ timesteps, num_inference_steps = retrieve_timesteps(
736
+ self.scheduler,
737
+ num_inference_steps=num_inference_steps,
738
+ device=device,
739
+ timesteps=None,
740
+ sigmas=sigmas,
741
+ mu=mu,
742
+ )
743
+
744
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
745
+ self._num_timesteps = len(timesteps)
746
+
747
+ # Support old different diffusers versions
748
+ if len(latent_image_ids.shape) == 3:
749
+ latent_image_ids = latent_image_ids[0]
750
+
751
+ if len(text_ids.shape) == 3:
752
+ text_ids = text_ids[0]
753
+
754
+ # 6. Denoising loop
755
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
756
+ for i, t in enumerate(timesteps):
757
+ if self.interrupt:
758
+ continue
759
+
760
+ latent_model_input = latents
761
+
762
+ if image_latents is not None:
763
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
764
+
765
+ # expand the latents if we are doing classifier free guidance
766
+ latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input
767
+
768
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
769
+ timestep = t.expand(latent_model_input.shape[0]).to(
770
+ device=latent_model_input.device, dtype=latent_model_input.dtype
771
+ )
772
+
773
+ # This predicts "v" for flow matching or eps for diffusion
774
+ noise_pred = self.transformer(
775
+ hidden_states=latent_model_input,
776
+ timestep=timestep,
777
+ encoder_hidden_states=prompt_embeds,
778
+ text_encoder_layers=prompt_layers,
779
+ joint_attention_kwargs=self.joint_attention_kwargs,
780
+ return_dict=False,
781
+ txt_ids=text_ids,
782
+ img_ids=latent_image_ids,
783
+ )[0]
784
+
785
+ # perform guidance
786
+ if guidance_scale > 1:
787
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
788
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
789
+
790
+ # compute the previous noisy sample x_t -> x_t-1
791
+ latents_dtype = latents.dtype
792
+ latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0]
793
+
794
+ if latents.dtype != latents_dtype:
795
+ if torch.backends.mps.is_available():
796
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
797
+ latents = latents.to(latents_dtype)
798
+
799
+ if callback_on_step_end is not None:
800
+ callback_kwargs = {}
801
+ for k in callback_on_step_end_tensor_inputs:
802
+ callback_kwargs[k] = locals()[k]
803
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
804
+
805
+ latents = callback_outputs.pop("latents", latents)
806
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
807
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
808
+
809
+ # call the callback, if provided
810
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
811
+ progress_bar.update()
812
+
813
+ if XLA_AVAILABLE:
814
+ xm.mark_step()
815
+
816
+ if output_type == "latent":
817
+ image = latents
818
+
819
+ else:
820
+ if do_patching:
821
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
822
+ else:
823
+ latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
824
+
825
+ latents = latents.unsqueeze(dim=2)
826
+ latents_device = latents[0].device
827
+ latents_dtype = latents[0].dtype
828
+ latents_mean = (
829
+ torch.tensor(self.vae.config.latents_mean)
830
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
831
+ .to(latents_device, latents_dtype)
832
+ )
833
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
834
+ latents_device, latents_dtype
835
+ )
836
+ latents_scaled = [latent / latents_std + latents_mean for latent in latents]
837
+ latents_scaled = torch.cat(latents_scaled, dim=0)
838
+ image = []
839
+ for scaled_latent in latents_scaled:
840
+ curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
841
+ curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
842
+ image.append(curr_image)
843
+ if len(image) == 1:
844
+ image = image[0]
845
+ else:
846
+ image = np.stack(image, axis=0)
847
+
848
+ # Offload all models
849
+ self.maybe_free_model_hooks()
850
+
851
+ if not return_dict:
852
+ return (image,)
853
+
854
+ return BriaFiboPipelineOutput(images=image)
855
+
856
+ def prepare_image_latents(
857
+ self,
858
+ image: torch.Tensor,
859
+ batch_size: int,
860
+ num_channels_latents: int,
861
+ height: int,
862
+ width: int,
863
+ dtype: torch.dtype,
864
+ device: torch.device,
865
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
866
+ ):
867
+ image = image.to(device=device, dtype=dtype)
868
+
869
+ height = int(height) // self.vae_scale_factor
870
+ width = int(width) // self.vae_scale_factor
871
+
872
+ # scaling
873
+ latents_mean = (
874
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
875
+ )
876
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
877
+ device, dtype
878
+ )
879
+
880
+ image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean
881
+ latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw]
882
+ image_latents_cthw = torch.concat(latents_scaled, dim=0)
883
+ image_latents_bchw = image_latents_cthw[:, :, 0, :, :]
884
+
885
+ image_latent_height, image_latent_width = image_latents_bchw.shape[2:]
886
+ image_latents_bsd = self._pack_latents_no_patch(
887
+ latents=image_latents_bchw,
888
+ batch_size=batch_size,
889
+ num_channels_latents=num_channels_latents,
890
+ height=image_latent_height,
891
+ width=image_latent_width,
892
+ )
894
+ image_ids = self._prepare_latent_image_ids(
895
+ batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype
896
+ )
897
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
898
+ image_ids[..., 0] = 1
899
+ return image_latents_bsd, image_ids
900
+
901
+ def check_inputs(
902
+ self,
903
+ prompt,
904
+ negative_prompt=None,
905
+ prompt_embeds=None,
906
+ negative_prompt_embeds=None,
907
+ callback_on_step_end_tensor_inputs=None,
908
+ max_sequence_length=None,
909
+ ):
910
+ if callback_on_step_end_tensor_inputs is not None and not all(
911
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
912
+ ):
913
+ raise ValueError(
914
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
915
+ )
916
+
917
+ if prompt is not None and prompt_embeds is not None:
918
+ raise ValueError(
919
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
920
+ " only forward one of the two."
921
+ )
922
+ elif prompt is None and prompt_embeds is None:
923
+ raise ValueError(
924
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
925
+ )
926
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
927
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
928
+
929
+ if negative_prompt is not None and negative_prompt_embeds is not None:
930
+ raise ValueError(
931
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
932
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
933
+ )
934
+
935
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
936
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
937
+ raise ValueError(
938
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
939
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
940
+ f" {negative_prompt_embeds.shape}."
941
+ )
942
+
943
+ if max_sequence_length is not None and max_sequence_length > 3000:
944
+ raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
945
+
946
+ def create_attention_matrix(self, attention_mask):
947
+ attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
948
+
949
+ # convert to 0 - keep, -inf ignore
950
+ attention_matrix = torch.where(
951
+ attention_matrix == 1, 0.0, -torch.inf
952
+ ) # Apply -inf to ignored tokens for nulling softmax score
953
+ return attention_matrix
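One subtlety in `__call__` above is how per-layer text-encoder features are aligned to the transformer's depth: if the encoder produced more layers than the transformer has blocks, the earliest layers are dropped; if it produced fewer, the last layer is repeated as padding. A torch-free sketch of that truncate-or-duplicate rule (`align_layers` is an illustrative name, not part of the pipeline):

```python
def align_layers(prompt_layers, total_layers):
    # Too many encoder layers: keep only the last `total_layers`
    # (the deepest, most semantic features).
    if len(prompt_layers) >= total_layers:
        return prompt_layers[len(prompt_layers) - total_layers:]
    # Too few: pad by repeating the final layer.
    return prompt_layers + [prompt_layers[-1]] * (total_layers - len(prompt_layers))

align_layers(["l0", "l1", "l2"], 2)  # ['l1', 'l2']
align_layers(["l0", "l1"], 4)        # ['l0', 'l1', 'l1', 'l1']
```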
requirements.txt ADDED
@@ -0,0 +1,133 @@
1
+ accelerate==1.12.0
2
+ aiofiles==24.1.0
3
+ annotated-doc==0.0.4
4
+ annotated-types==0.7.0
5
+ anyio==4.12.1
6
+ asttokens==3.0.1
7
+ attrs==25.4.0
8
+ boto3==1.42.28
9
+ botocore==1.42.28
10
+ brotli==1.2.0
11
+ certifi==2026.1.4
12
+ cffi==2.0.0 ; platform_python_implementation != 'PyPy'
13
+ charset-normalizer==3.4.4
14
+ click==8.3.1
15
+ colorama==0.4.6 ; sys_platform == 'win32'
16
+ cryptography==46.0.3
17
+ cuda-bindings==12.9.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
18
+ cuda-pathfinder==1.3.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
19
+ decorator==5.2.1
20
+ diffusers @ git+https://github.com/huggingface/diffusers@956bdcc3ea4897eaeb6c828b8433bdcae71e9f0f
21
+ einops==0.8.2
22
+ exceptiongroup==1.3.1 ; python_full_version < '3.11'
23
+ executing==2.2.1
24
+ fal-client==0.12.0
25
+ fastapi==0.128.0
26
+ ffmpy==1.0.0
27
+ filelock==3.20.3
28
+ fsspec==2026.1.0
29
+ gradio==6.4.0
30
+ gradio-client==2.0.3
31
+ groovy==0.1.2
32
+ h11==0.16.0
33
+ hf-xet==1.2.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
34
+ httpcore==1.0.9
35
+ httpx==0.28.1
36
+ httpx-sse==0.4.3
37
+ huggingface-hub==1.3.4
38
+ idna==3.11
39
+ importlib-metadata==8.7.1
40
+ ipython==8.38.0 ; python_full_version < '3.11'
41
+ ipython==9.9.0 ; python_full_version >= '3.11'
42
+ ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
43
+ jedi==0.19.2
44
+ jinja2==3.1.6
45
+ jmespath==1.0.1
46
+ jsonschema==4.26.0
47
+ jsonschema-specifications==2025.9.1
48
+ markdown-it-py==4.0.0
49
+ markupsafe==3.0.3
50
+ matplotlib-inline==0.2.1
51
+ mcp==1.26.0
52
+ mdurl==0.1.2
53
+ mpmath==1.3.0
54
+ msgpack==1.1.2
55
+ networkx==3.4.2 ; python_full_version < '3.11'
56
+ networkx==3.6.1 ; python_full_version >= '3.11'
57
+ numpy==2.2.6 ; python_full_version < '3.11'
58
+ numpy==2.4.1 ; python_full_version >= '3.11'
59
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
60
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
61
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
62
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
63
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
64
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
65
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
66
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
67
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
68
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
69
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
70
+ nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
71
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
72
+ nvidia-nvshmem-cu12==3.4.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
73
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
74
+ orjson==3.11.5
75
+ packaging==26.0
76
+ pandas==2.3.3
77
+ parso==0.8.5
78
+ peft==0.18.1
79
+ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
80
+ pillow==12.1.0
81
+ prompt-toolkit==3.0.52
82
+ psutil==5.9.8
83
+ ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
84
+ pure-eval==0.2.3
85
+ pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
86
+ pydantic==2.12.5
87
+ pydantic-core==2.41.5
88
+ pydantic-settings==2.12.0
89
+ pydub==0.25.1
90
+ pygments==2.19.2
91
+ pyjwt==2.10.1
92
+ python-dateutil==2.9.0.post0
93
+ python-dotenv==1.2.1
94
+ python-multipart==0.0.22
95
+ pytz==2025.2
96
+ pywin32==311 ; sys_platform == 'win32'
97
+ pyyaml==6.0.3
98
+ referencing==0.37.0
99
+ regex==2026.1.15
100
+ requests==2.32.5
101
+ rich==14.3.1
102
+ rpds-py==0.30.0
103
+ s3transfer==0.16.0
104
+ safehttpx==0.1.7
105
+ safetensors==0.7.0
106
+ semantic-version==2.10.0
107
+ setuptools==80.10.2 ; python_full_version >= '3.12'
108
+ shellingham==1.5.4
109
+ six==1.17.0
110
+ spaces==0.47.0
111
+ sse-starlette==3.2.0
112
+ stack-data==0.6.3
113
+ starlette==0.50.0
114
+ sympy==1.14.0
115
+ tokenizers==0.22.2
116
+ tomlkit==0.13.3
117
+ torch==2.10.0
118
+ torchvision==0.25.0
119
+ tqdm==4.67.1
120
+ traitlets==5.14.3
121
+ transformers==5.0.0
122
+ triton==3.6.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
123
+ typer==0.21.1
124
+ typer-slim==0.21.1
125
+ typing-extensions==4.15.0
126
+ typing-inspection==0.4.2
127
+ tzdata==2025.3
128
+ ujson==5.11.0
129
+ urllib3==2.6.3
130
+ uvicorn==0.40.0
131
+ wcwidth==0.2.14
132
+ websockets==16.0
133
+ zipp==3.23.0
utils.py ADDED
@@ -0,0 +1,113 @@
1
+ """Camera angle data structures for Fibo Edit."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+
6
+
7
+ class View(Enum):
8
+ """Camera view angles"""
9
+ BACK_VIEW = "back view"
10
+ BACK_LEFT_QUARTER = "back-left quarter view"
11
+ BACK_RIGHT_QUARTER = "back-right quarter view"
12
+ FRONT_VIEW = "front view"
13
+ FRONT_LEFT_QUARTER = "front-left quarter view"
14
+ FRONT_RIGHT_QUARTER = "front-right quarter view"
15
+ LEFT_SIDE = "left side view"
16
+ RIGHT_SIDE = "right side view"
17
+
18
+
19
+ class Shot(Enum):
20
+ """
21
+ Camera shot angles (measured from horizontal/eye-level as 0 degrees)
22
+
23
+ - ELEVATED: 45-60 degrees above subject (moderately elevated)
24
+ - EYE_LEVEL: 0 degrees (horizontal with subject)
25
+ - HIGH_ANGLE: 60-90 degrees above subject (steep overhead, bird's eye)
26
+ - LOW_ANGLE: Below eye level (looking up at subject)
27
+ """
28
+ ELEVATED = "elevated shot"
29
+ EYE_LEVEL = "eye-level shot"
30
+ HIGH_ANGLE = "high-angle shot"
31
+ LOW_ANGLE = "low-angle shot"
32
+
33
+
34
+ class Zoom(Enum):
35
+ """Camera zoom levels"""
36
+ CLOSE_UP = "close-up"
37
+ MEDIUM = "medium shot"
38
+ WIDE = "wide shot"
39
+
40
+
41
+ @dataclass
42
+ class AngleInstruction:
43
+ view: View
44
+ shot: Shot
45
+ zoom: Zoom
46
+
47
+ def __str__(self):
48
+ return f"<sks> {self.view.value} {self.shot.value} {self.zoom.value}"
49
+
50
+ @classmethod
51
+ def from_camera_params(cls, rotation: float, tilt: float, zoom: float) -> "AngleInstruction":
52
+ """
53
+ Create an AngleInstruction from camera parameters.
54
+
55
+ Args:
56
+ rotation: Horizontal rotation in degrees (-180 to 180)
57
+ -180/180: back view, -90: left view, 0: front view, 90: right view
58
+ tilt: Vertical tilt (-1 to 1)
59
+ -1 to -0.33: low-angle shot
60
+ -0.33 to 0.33: eye-level shot
61
+ 0.33 to 0.66: elevated shot
62
+ 0.66 to 1: high-angle shot
63
+ zoom: Zoom level (0 to 10)
64
+ 0-3.33: wide shot
65
+ 3.33-6.66: medium shot
66
+ 6.66-10: close-up
67
+
68
+ Returns:
69
+ AngleInstruction instance
70
+ """
71
+ # Map rotation to View
72
+ # Normalize rotation to -180 to 180 range
73
+ rotation = rotation % 360
74
+ if rotation > 180:
75
+ rotation -= 360
76
+
77
+ # Determine view based on rotation
78
+ if -157.5 <= rotation < -112.5:
79
+ view = View.BACK_LEFT_QUARTER
80
+ elif -112.5 <= rotation < -67.5:
81
+ view = View.LEFT_SIDE
82
+ elif -67.5 <= rotation < -22.5:
83
+ view = View.FRONT_LEFT_QUARTER
84
+ elif -22.5 <= rotation < 22.5:
85
+ view = View.FRONT_VIEW
86
+ elif 22.5 <= rotation < 67.5:
87
+ view = View.FRONT_RIGHT_QUARTER
88
+ elif 67.5 <= rotation < 112.5:
89
+ view = View.RIGHT_SIDE
90
+ elif 112.5 <= rotation < 157.5:
91
+ view = View.BACK_RIGHT_QUARTER
92
+ else: # 157.5 to 180 or -180 to -157.5
93
+ view = View.BACK_VIEW
94
+
95
+ # Map tilt to Shot
96
+ if tilt < -0.33:
97
+ shot = Shot.LOW_ANGLE
98
+ elif tilt < 0.33:
99
+ shot = Shot.EYE_LEVEL
100
+ elif tilt < 0.66:
101
+ shot = Shot.ELEVATED
102
+ else:
103
+ shot = Shot.HIGH_ANGLE
104
+
105
+ # Map zoom to Zoom
106
+ if zoom < 3.33:
107
+ zoom_level = Zoom.WIDE
108
+ elif zoom < 6.66:
109
+ zoom_level = Zoom.MEDIUM
110
+ else:
111
+ zoom_level = Zoom.CLOSE_UP
112
+
113
+ return cls(view=view, shot=shot, zoom=zoom_level)
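The rotation-to-view mapping in `from_camera_params` bins the normalized angle into eight 45° sectors centered on the compass views. A self-contained sketch of just that binning step, mirroring the thresholds above (`view_from_rotation` is an illustrative standalone function, not part of `utils.py`):

```python
def view_from_rotation(rotation: float) -> str:
    # Normalize to (-180, 180], matching from_camera_params.
    rotation = rotation % 360
    if rotation > 180:
        rotation -= 360
    # Eight 45-degree sectors; anything past +/-157.5 is the back view.
    bins = [
        (-157.5, -112.5, "back-left quarter view"),
        (-112.5, -67.5, "left side view"),
        (-67.5, -22.5, "front-left quarter view"),
        (-22.5, 22.5, "front view"),
        (22.5, 67.5, "front-right quarter view"),
        (67.5, 112.5, "right side view"),
        (112.5, 157.5, "back-right quarter view"),
    ]
    for lo, hi, name in bins:
        if lo <= rotation < hi:
            return name
    return "back view"

view_from_rotation(0)    # 'front view'
view_from_rotation(225)  # normalizes to -135 -> 'back-left quarter view'
view_from_rotation(180)  # 'back view'
```

Together with the tilt and zoom thresholds, this yields prompt strings such as `<sks> front view eye-level shot medium shot` for `from_camera_params(0, 0, 5)`.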