Spaces: Running on Zero
derek tingle committed · Commit 6062b47 · Parent(s): 56bbe8e
Initial commit
- README.md +52 -13
- app.py +1043 -0
- fibo_edit_pipeline.py +953 -0
- requirements.txt +133 -0
- utils.py +113 -0
README.md CHANGED
@@ -1,13 +1,52 @@
# Fibo Edit — Camera Angle Control

Fibo Edit with Multi-Angle LoRA for precise camera control. Control rotation, tilt, and zoom to generate images from any angle.

## Features

- 🎬 Interactive 3D camera control widget
- 🎨 Multi-angle image generation using the Fibo Edit model
- 📐 Precise control over rotation, tilt, and zoom
- 🤖 BRIA API integration for structured captions
- ⚡ GPU-accelerated inference with Spaces GPU

## Setup

### Required Secrets

This Space requires the following environment variable to be set as a **HuggingFace Space Secret**:

- `BRIA_API_TOKEN` - Your BRIA API token for structured caption generation

To add this secret:

1. Go to your Space's Settings
2. Navigate to "Repository secrets"
3. Add a new secret named `BRIA_API_TOKEN` with your API token value
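For running the app outside of Spaces, the same variable can be exported in the shell before launching; the token value below is a placeholder, not a real credential:

```shell
# Placeholder — substitute your actual BRIA API token
export BRIA_API_TOKEN="your-token-here"
python app.py
```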
### Hardware Requirements

This Space requires a GPU to run. Make sure to configure your Space to use a GPU instance.

## Usage

1. Upload an input image
2. Use the 3D camera control or sliders to adjust:
   - **Rotation**: -180° to +180° (0° = front; ±180° = back)
   - **Vertical Tilt**: -1 (low angle) to +1 (high angle)
   - **Zoom**: 0 (wide) to 10 (close-up)
3. Click "Generate" to create the image from the new camera angle
4. View the structured caption from the BRIA API in the accordion

## Model Information

- **Base Model**: [briaai/FIBO-Edit](https://huggingface.co/briaai/FIBO-Edit)
- **LoRA**: [briaai/fibo_edit_multi_angle_full_0121_full_1k](https://huggingface.co/briaai/fibo_edit_multi_angle_full_0121_full_1k)
- **Text Encoder**: SmolLM3
- **Scheduler**: FlowMatchEulerDiscreteScheduler

## Credits

Built with:

- [Gradio](https://gradio.app/)
- [Diffusers](https://huggingface.co/docs/diffusers)
- [BRIA AI](https://bria.ai/)
app.py ADDED
@@ -0,0 +1,1043 @@
import base64
import json
import os
import random
import time
from io import BytesIO
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image

from fibo_edit_pipeline import BriaFiboEditPipeline
from utils import AngleInstruction

# --- Configuration ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# Run locally or on HuggingFace Spaces
RUN_LOCAL = True

# Model paths
BASE_CHECKPOINT = "briaai/FIBO-Edit"  # HuggingFace model ID
LORA_CHECKPOINT = "briaai/fibo_edit_multi_angle_full_0121_full_1k"  # HuggingFace LoRA model ID

# BRIA API configuration
BRIA_API_URL = "https://engine.prod.bria-api.com/v2/structured_prompt/generate/pro"
BRIA_API_TOKEN = os.environ.get("BRIA_API_TOKEN")

if not BRIA_API_TOKEN:
    raise ValueError(
        "BRIA_API_TOKEN environment variable is not set. "
        "Please add it as a HuggingFace Space secret."
    )

# Generation defaults
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_SEED = 100050

MAX_SEED = np.iinfo(np.int32).max

print("🚀 Starting Fibo Edit Multi-Angle LoRA Gradio App")
print(f"Device: {device}")
print(f"Base checkpoint: {BASE_CHECKPOINT}")
print(f"LoRA checkpoint: {LORA_CHECKPOINT}")


# --- Helper Functions ---
def load_pipeline_fiboedit(
    checkpoint: str,
    lora_checkpoint: Optional[str] = None,
    lora_scale: Optional[float] = None,
    fuse_lora: bool = True,
):
    """
    Load the Fibo Edit pipeline using BriaFiboEditPipeline with optional LoRA weights.

    Args:
        checkpoint: HuggingFace model ID for base model
        lora_checkpoint: Optional HuggingFace model ID for LoRA weights
        lora_scale: Scale for LoRA weights when fusing (default None = 1.0)
        fuse_lora: Whether to fuse LoRA into base weights (default True)

    Returns:
        Loaded BriaFiboEditPipeline
    """
    print(f"Loading BriaFiboEditPipeline from {checkpoint}")
    if lora_checkpoint:
        print(f"  with LoRA from {lora_checkpoint}")

    # Load pipeline from HuggingFace
    print("Loading pipeline...")
    pipe = BriaFiboEditPipeline.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
    print(f"  Pipeline loaded from {checkpoint}")

    # Load LoRA weights if provided (PEFT format)
    if lora_checkpoint:
        print(f"Loading PEFT LoRA from {lora_checkpoint}...")
        from peft import PeftModel

        print("  Loading PEFT adapter onto transformer...")
        pipe.transformer = PeftModel.from_pretrained(pipe.transformer, lora_checkpoint)
        print("  PEFT adapter loaded successfully")

        if fuse_lora:
            print("  Merging LoRA into base weights...")
            if hasattr(pipe.transformer, "merge_and_unload"):
                pipe.transformer = pipe.transformer.merge_and_unload()
                print("  LoRA merged and unloaded")
            else:
                print("  [WARN] transformer.merge_and_unload() not available")

    print("✅ Pipeline loaded successfully!")
    return pipe


def generate_structured_caption(
    image: Image.Image, prompt: str, seed: int = 1
) -> Optional[dict]:
    """Generate structured caption using BRIA API."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    image_bytes = base64.b64encode(buffered.getvalue()).decode("utf-8")

    payload = {
        "seed": seed,
        "sync": True,
        "images": [image_bytes],
        "prompt": prompt,
    }

    headers = {
        "Content-Type": "application/json",
        "api_token": BRIA_API_TOKEN,
    }

    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = requests.post(
                BRIA_API_URL, json=payload, headers=headers, timeout=60
            )
            response.raise_for_status()
            data = response.json()
            structured_prompt_str = data["result"]["structured_prompt"]
            return json.loads(structured_prompt_str)
        except Exception as e:
            if attempt == max_retries - 1:
                print(f"Failed to generate structured caption: {e}")
                return None
            time.sleep(3)

    return None
# --- Model Loading ---
print("Loading Fibo Edit pipeline...")

try:
    pipe = load_pipeline_fiboedit(
        checkpoint=BASE_CHECKPOINT,
        lora_checkpoint=LORA_CHECKPOINT,
        lora_scale=None,
        fuse_lora=True,
    )

    if torch.cuda.is_available():
        mem_allocated = torch.cuda.memory_allocated(0) / 1024**3
        print(f"  GPU memory allocated: {mem_allocated:.2f} GB")

except Exception as e:
    print(f"❌ Error loading pipeline: {e}")
    import traceback

    traceback.print_exc()
    raise


def build_camera_prompt(
    rotate_deg: float = 0.0, zoom: float = 0.0, vertical_tilt: float = 0.0
) -> Tuple[str, AngleInstruction]:
    """Build a natural language camera instruction from parameters."""
    # Create AngleInstruction from camera parameters
    angle_instruction = AngleInstruction.from_camera_params(
        rotation=rotate_deg, tilt=vertical_tilt, zoom=zoom
    )

    # Generate natural language description
    view_map = {
        "back view": "view from the opposite side",
        "back-left quarter view": "rotate 135 degrees left",
        "back-right quarter view": "rotate 135 degrees right",
        "front view": "keep the front view",
        "front-left quarter view": "rotate 45 degrees left",
        "front-right quarter view": "rotate 45 degrees right",
        "left side view": "rotate 90 degrees left",
        "right side view": "rotate 90 degrees right",
    }

    shot_map = {
        "elevated shot": "with an elevated viewing angle",
        "eye-level shot": "with an eye-level viewing angle",
        "high-angle shot": "with a high-angle viewing angle",
        "low-angle shot": "with a low-angle viewing angle",
    }

    zoom_map = {
        "close-up": "and make it a close-up shot",
        "medium shot": "",  # Omit medium shot
        "wide shot": "and make it a wide shot",
    }

    view_text = view_map[angle_instruction.view.value]
    shot_text = shot_map[angle_instruction.shot.value]
    zoom_text = zoom_map[angle_instruction.zoom.value]

    # Construct the natural language prompt starting with "Change the viewing angle"
    parts = [view_text, shot_text]
    if zoom_text:  # Only add zoom if not empty (medium shot is omitted)
        parts.append(zoom_text)
    natural_prompt = "Change the viewing angle: " + ", ".join(parts)

    return natural_prompt, angle_instruction
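`build_camera_prompt` relies on `AngleInstruction` from utils.py, which is part of this commit but not shown in this excerpt. As a rough illustration of the kind of mapping it must perform, the sketch below buckets a raw rotation angle into the eight view names that `view_map` expects; the function name and bucketing here are illustrative, not part of app.py:

```python
def view_from_rotation(deg: float) -> str:
    # Hypothetical bucketing of a rotation angle into the view names used by
    # view_map; the real mapping lives in utils.AngleInstruction.
    names = {
        0: "front view",
        45: "front-right quarter view",
        90: "right side view",
        135: "back-right quarter view",
        180: "back view",
        -45: "front-left quarter view",
        -90: "left side view",
        -135: "back-left quarter view",
        -180: "back view",
    }
    # Snap to the nearest labeled angle (ties keep the first candidate)
    nearest = min(names, key=lambda step: abs(step - deg))
    return names[nearest]

print(view_from_rotation(50))    # front-right quarter view
print(view_from_rotation(-170))  # back view
```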
def fetch_structured_caption(
    image: Optional[Image.Image] = None,
    rotate_deg: float = 0.0,
    zoom: float = 0.0,
    vertical_tilt: float = 0.0,
    seed: int = 0,
    randomize_seed: bool = True,
    prev_output: Optional[Image.Image] = None,
) -> Tuple[int, str, dict, Image.Image]:
    """Fetch structured caption from BRIA API."""

    # Build natural language prompt and angle instruction
    natural_prompt, angle_instruction = build_camera_prompt(
        rotate_deg, zoom, vertical_tilt
    )
    print(f"Natural Language Prompt: {natural_prompt}")
    print(f"Angle Instruction: {str(angle_instruction)}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Get input image
    if image is not None:
        if isinstance(image, Image.Image):
            input_image = image.convert("RGB")
        elif hasattr(image, "name"):
            input_image = Image.open(image.name).convert("RGB")
        else:
            input_image = image
    elif prev_output:
        input_image = prev_output.convert("RGB")
    else:
        raise gr.Error("Please upload an image first.")

    # Generate structured caption using BRIA API
    print("Generating structured caption from BRIA API...")
    structured_caption = generate_structured_caption(
        input_image, natural_prompt, seed=seed
    )

    if structured_caption is None:
        raise gr.Error("Failed to generate structured caption from BRIA API")

    # Replace edit_instruction with angle instruction string
    structured_caption["edit_instruction"] = str(angle_instruction)

    print(
        f"Structured caption received: {json.dumps(structured_caption, ensure_ascii=False)}"
    )

    return seed, natural_prompt, structured_caption, input_image


@spaces.GPU
def generate_image_from_caption(
    input_image: Image.Image,
    structured_caption: dict,
    seed: int,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 50,
) -> Image.Image:
    """Generate image using Fibo Edit pipeline with structured caption."""

    structured_prompt = json.dumps(structured_caption, ensure_ascii=False)
    print("Generating image with structured prompt...")

    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        image=input_image,
        prompt=structured_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=1,
    ).images[0]

    return result
# --- 3D Camera Control Component ---
# Using gr.HTML directly with templates (Gradio 6 style)

CAMERA_3D_HTML_TEMPLATE = """
<div id="camera-control-wrapper" style="width: 100%; height: 400px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
  <div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 11px; color: #00ff88; white-space: nowrap; z-index: 10; max-width: 90%; overflow: hidden; text-overflow: ellipsis;"></div>
  <div id="control-legend" style="position: absolute; top: 10px; left: 10px; background: rgba(0,0,0,0.7); padding: 8px 12px; border-radius: 8px; font-family: system-ui; font-size: 11px; color: #fff; z-index: 10;">
    <div style="margin-bottom: 4px;"><span style="color: #00ff88;">●</span> Rotation (↔)</div>
    <div style="margin-bottom: 4px;"><span style="color: #ff69b4;">●</span> Vertical Tilt (↕)</div>
    <div><span style="color: #ffa500;">●</span> Distance/Zoom</div>
  </div>
</div>
"""

CAMERA_3D_JS = """
(() => {
  const wrapper = element.querySelector('#camera-control-wrapper');
  const promptOverlay = element.querySelector('#prompt-overlay');

  const initScene = () => {
    if (typeof THREE === 'undefined') {
      setTimeout(initScene, 100);
      return;
    }

    const scene = new THREE.Scene();
    scene.background = new THREE.Color(0x1a1a1a);

    const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
    camera.position.set(4, 3, 4);
    camera.lookAt(0, 0.75, 0);

    const renderer = new THREE.WebGLRenderer({ antialias: true });
    renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
    renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
    wrapper.insertBefore(renderer.domElement, wrapper.firstChild);

    scene.add(new THREE.AmbientLight(0xffffff, 0.6));
    const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
    dirLight.position.set(5, 10, 5);
    scene.add(dirLight);

    scene.add(new THREE.GridHelper(6, 12, 0x333333, 0x222222));

    const CENTER = new THREE.Vector3(0, 0.75, 0);
    const BASE_DISTANCE = 2.0;
    const ROTATION_RADIUS = 2.2;
    const TILT_RADIUS = 1.6;

    let rotateDeg = props.value?.rotate_deg || 0;
    let zoom = props.value?.zoom || 5.0;
    let verticalTilt = props.value?.vertical_tilt || 0;

    const rotateSteps = [-180, -135, -90, -45, 0, 45, 90, 135, 180];
    const zoomSteps = [0, 5, 10];
    const tiltSteps = [-1, -0.5, 0, 0.5, 1];

    function snapToNearest(value, steps) {
      return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
    }
    function createPlaceholderTexture() {
      const canvas = document.createElement('canvas');
      canvas.width = 256;
      canvas.height = 256;
      const ctx = canvas.getContext('2d');
      ctx.fillStyle = '#3a3a4a';
      ctx.fillRect(0, 0, 256, 256);
      ctx.fillStyle = '#ffcc99';
      ctx.beginPath();
      ctx.arc(128, 128, 80, 0, Math.PI * 2);
      ctx.fill();
      ctx.fillStyle = '#333';
      ctx.beginPath();
      ctx.arc(100, 110, 10, 0, Math.PI * 2);
      ctx.arc(156, 110, 10, 0, Math.PI * 2);
      ctx.fill();
      ctx.strokeStyle = '#333';
      ctx.lineWidth = 3;
      ctx.beginPath();
      ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
      ctx.stroke();
      return new THREE.CanvasTexture(canvas);
    }

    let currentTexture = createPlaceholderTexture();
    const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
    let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
    targetPlane.position.copy(CENTER);
    scene.add(targetPlane);

    function updateTextureFromUrl(url) {
      if (!url) {
        planeMaterial.map = createPlaceholderTexture();
        planeMaterial.needsUpdate = true;
        scene.remove(targetPlane);
        targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
        targetPlane.position.copy(CENTER);
        scene.add(targetPlane);
        return;
      }

      const loader = new THREE.TextureLoader();
      loader.crossOrigin = 'anonymous';
      loader.load(url, (texture) => {
        texture.minFilter = THREE.LinearFilter;
        texture.magFilter = THREE.LinearFilter;
        planeMaterial.map = texture;
        planeMaterial.needsUpdate = true;

        const img = texture.image;
        if (img && img.width && img.height) {
          const aspect = img.width / img.height;
          const maxSize = 1.4;
          let planeWidth, planeHeight;
          if (aspect > 1) {
            planeWidth = maxSize;
            planeHeight = maxSize / aspect;
          } else {
            planeHeight = maxSize;
            planeWidth = maxSize * aspect;
          }
          scene.remove(targetPlane);
          targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(planeWidth, planeHeight), planeMaterial);
          targetPlane.position.copy(CENTER);
          scene.add(targetPlane);
        }
      });
    }

    if (props.imageUrl) {
      updateTextureFromUrl(props.imageUrl);
    }

    const cameraGroup = new THREE.Group();
    const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
    const body = new THREE.Mesh(new THREE.BoxGeometry(0.28, 0.2, 0.35), bodyMat);
    cameraGroup.add(body);
    const lens = new THREE.Mesh(
      new THREE.CylinderGeometry(0.08, 0.1, 0.16, 16),
      new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
    );
    lens.rotation.x = Math.PI / 2;
    lens.position.z = 0.24;
    cameraGroup.add(lens);
    scene.add(cameraGroup);

    const rotationArcPoints = [];
    for (let i = 0; i <= 64; i++) {
      const angle = THREE.MathUtils.degToRad((360 * i / 64));
      rotationArcPoints.push(new THREE.Vector3(ROTATION_RADIUS * Math.sin(angle), 0.05, ROTATION_RADIUS * Math.cos(angle)));
    }
    const rotationCurve = new THREE.CatmullRomCurve3(rotationArcPoints);
    const rotationArc = new THREE.Mesh(
      new THREE.TubeGeometry(rotationCurve, 64, 0.035, 8, true),
      new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
    );
    scene.add(rotationArc);

    const rotationHandle = new THREE.Mesh(
      new THREE.SphereGeometry(0.16, 16, 16),
      new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
    );
    rotationHandle.userData.type = 'rotation';
    scene.add(rotationHandle);

    const tiltArcPoints = [];
    for (let i = 0; i <= 32; i++) {
      const angle = THREE.MathUtils.degToRad(-45 + (90 * i / 32));
      tiltArcPoints.push(new THREE.Vector3(-0.7, TILT_RADIUS * Math.sin(angle) + CENTER.y, TILT_RADIUS * Math.cos(angle)));
    }
    const tiltCurve = new THREE.CatmullRomCurve3(tiltArcPoints);
    const tiltArc = new THREE.Mesh(
      new THREE.TubeGeometry(tiltCurve, 32, 0.035, 8, false),
      new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
    );
    scene.add(tiltArc);

    const tiltHandle = new THREE.Mesh(
      new THREE.SphereGeometry(0.16, 16, 16),
      new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
    );
    tiltHandle.userData.type = 'tilt';
    scene.add(tiltHandle);

    const distanceLineGeo = new THREE.BufferGeometry();
    const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
    scene.add(distanceLine);

    const distanceHandle = new THREE.Mesh(
      new THREE.SphereGeometry(0.16, 16, 16),
      new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
    );
    distanceHandle.userData.type = 'distance';
    scene.add(distanceHandle);

    function buildPromptText(rot, zoomVal, tilt) {
      const parts = [];
      if (rot !== 0) {
        const dir = rot > 0 ? 'right' : 'left';
        parts.push('Rotate ' + Math.abs(rot) + '° ' + dir);
      }
      if (zoomVal >= 6.66) parts.push('Close-up');
      else if (zoomVal >= 3.33) parts.push('Medium shot');
      else parts.push('Wide angle');
      if (tilt >= 0.66) parts.push("High angle");
      else if (tilt >= 0.33) parts.push("Elevated");
      else if (tilt <= -0.33) parts.push("Low angle");
      else parts.push("Eye level");
      return parts.length > 0 ? parts.join(' • ') : 'No camera movement';
    }
    function updatePositions() {
      const rotRad = THREE.MathUtils.degToRad(rotateDeg);
      // Map zoom 0-10 to distance: zoom 0 = far (3.0), zoom 10 = close (1.0)
      const distance = 3.0 - (zoom / 10) * 2.0;
      const tiltAngle = verticalTilt * 35;
      const tiltRad = THREE.MathUtils.degToRad(tiltAngle);

      const camX = distance * Math.sin(rotRad) * Math.cos(tiltRad);
      const camY = distance * Math.sin(tiltRad) + CENTER.y;
      const camZ = distance * Math.cos(rotRad) * Math.cos(tiltRad);

      cameraGroup.position.set(camX, camY, camZ);
      cameraGroup.lookAt(CENTER);

      rotationHandle.position.set(ROTATION_RADIUS * Math.sin(rotRad), 0.05, ROTATION_RADIUS * Math.cos(rotRad));

      const tiltHandleAngle = THREE.MathUtils.degToRad(tiltAngle);
      tiltHandle.position.set(-0.7, TILT_RADIUS * Math.sin(tiltHandleAngle) + CENTER.y, TILT_RADIUS * Math.cos(tiltHandleAngle));

      const handleDist = distance - 0.4;
      distanceHandle.position.set(
        handleDist * Math.sin(rotRad) * Math.cos(tiltRad),
        handleDist * Math.sin(tiltRad) + CENTER.y,
        handleDist * Math.cos(rotRad) * Math.cos(tiltRad)
      );
      distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);

      promptOverlay.textContent = buildPromptText(rotateDeg, zoom, verticalTilt);
    }
| 536 |
+
function updatePropsAndTrigger() {
  const rotSnap = snapToNearest(rotateDeg, rotateSteps);
  const zoomSnap = snapToNearest(zoom, zoomSteps);
  const tiltSnap = snapToNearest(verticalTilt, tiltSteps);

  props.value = { rotate_deg: rotSnap, zoom: zoomSnap, vertical_tilt: tiltSnap };
  trigger('change', props.value);
}

const raycaster = new THREE.Raycaster();
const mouse = new THREE.Vector2();
let isDragging = false;
let dragTarget = null;
let dragStartMouse = new THREE.Vector2();
let dragStartZoom = 0;
const intersection = new THREE.Vector3();

const canvas = renderer.domElement;

canvas.addEventListener('mousedown', (e) => {
  const rect = canvas.getBoundingClientRect();
  mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
  mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;

  raycaster.setFromCamera(mouse, camera);
  const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);

  if (intersects.length > 0) {
    isDragging = true;
    dragTarget = intersects[0].object;
    dragTarget.material.emissiveIntensity = 1.0;
    dragTarget.scale.setScalar(1.3);
    dragStartMouse.copy(mouse);
    dragStartZoom = zoom;
    canvas.style.cursor = 'grabbing';
  }
});

canvas.addEventListener('mousemove', (e) => {
  const rect = canvas.getBoundingClientRect();
  mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
  mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;

  if (isDragging && dragTarget) {
    raycaster.setFromCamera(mouse, camera);

    if (dragTarget.userData.type === 'rotation') {
      const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
      if (raycaster.ray.intersectPlane(plane, intersection)) {
        let angle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
        rotateDeg = THREE.MathUtils.clamp(angle, -180, 180);
      }
    } else if (dragTarget.userData.type === 'tilt') {
      const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), 0.7);
      if (raycaster.ray.intersectPlane(plane, intersection)) {
        const relY = intersection.y - CENTER.y;
        const relZ = intersection.z;
        const angle = THREE.MathUtils.radToDeg(Math.atan2(relY, relZ));
        verticalTilt = THREE.MathUtils.clamp(angle / 35, -1, 1);
      }
    } else if (dragTarget.userData.type === 'distance') {
      const deltaY = mouse.y - dragStartMouse.y;
      zoom = THREE.MathUtils.clamp(dragStartZoom + deltaY * 20, 0, 10);
    }
    updatePositions();
  } else {
    raycaster.setFromCamera(mouse, camera);
    const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);
    [rotationHandle, tiltHandle, distanceHandle].forEach(h => {
      h.material.emissiveIntensity = 0.5;
      h.scale.setScalar(1);
    });
    if (intersects.length > 0) {
      intersects[0].object.material.emissiveIntensity = 0.8;
      intersects[0].object.scale.setScalar(1.1);
      canvas.style.cursor = 'grab';
    } else {
      canvas.style.cursor = 'default';
    }
  }
});

const onMouseUp = () => {
  if (dragTarget) {
    dragTarget.material.emissiveIntensity = 0.5;
    dragTarget.scale.setScalar(1);

    const targetRot = snapToNearest(rotateDeg, rotateSteps);
    const targetZoom = snapToNearest(zoom, zoomSteps);
    const targetTilt = snapToNearest(verticalTilt, tiltSteps);

    const startRot = rotateDeg, startZoom = zoom, startTilt = verticalTilt;
    const startTime = Date.now();

    function animateSnap() {
      const t = Math.min((Date.now() - startTime) / 200, 1);
      const ease = 1 - Math.pow(1 - t, 3);

      rotateDeg = startRot + (targetRot - startRot) * ease;
      zoom = startZoom + (targetZoom - startZoom) * ease;
      verticalTilt = startTilt + (targetTilt - startTilt) * ease;

      updatePositions();
      if (t < 1) requestAnimationFrame(animateSnap);
      else updatePropsAndTrigger();
    }
    animateSnap();
  }
  isDragging = false;
  dragTarget = null;
  canvas.style.cursor = 'default';
};

canvas.addEventListener('mouseup', onMouseUp);
canvas.addEventListener('mouseleave', onMouseUp);

canvas.addEventListener('touchstart', (e) => {
  e.preventDefault();
  const touch = e.touches[0];
  const rect = canvas.getBoundingClientRect();
  mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
  mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;

  raycaster.setFromCamera(mouse, camera);
  const intersects = raycaster.intersectObjects([rotationHandle, tiltHandle, distanceHandle]);

  if (intersects.length > 0) {
    isDragging = true;
    dragTarget = intersects[0].object;
    dragTarget.material.emissiveIntensity = 1.0;
    dragTarget.scale.setScalar(1.3);
    dragStartMouse.copy(mouse);
    dragStartZoom = zoom;
  }
}, { passive: false });

canvas.addEventListener('touchmove', (e) => {
  e.preventDefault();
  const touch = e.touches[0];
  const rect = canvas.getBoundingClientRect();
  mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
  mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;

  if (isDragging && dragTarget) {
    raycaster.setFromCamera(mouse, camera);

    if (dragTarget.userData.type === 'rotation') {
      const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
      if (raycaster.ray.intersectPlane(plane, intersection)) {
        let angle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
        rotateDeg = THREE.MathUtils.clamp(angle, -180, 180);
      }
    } else if (dragTarget.userData.type === 'tilt') {
      const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), 0.7);
      if (raycaster.ray.intersectPlane(plane, intersection)) {
        const relY = intersection.y - CENTER.y;
        const relZ = intersection.z;
        const angle = THREE.MathUtils.radToDeg(Math.atan2(relY, relZ));
        verticalTilt = THREE.MathUtils.clamp(angle / 35, -1, 1);
      }
    } else if (dragTarget.userData.type === 'distance') {
      const deltaY = mouse.y - dragStartMouse.y;
      zoom = THREE.MathUtils.clamp(dragStartZoom + deltaY * 20, 0, 10);
    }
    updatePositions();
  }
}, { passive: false });

canvas.addEventListener('touchend', (e) => { e.preventDefault(); onMouseUp(); }, { passive: false });
canvas.addEventListener('touchcancel', (e) => { e.preventDefault(); onMouseUp(); }, { passive: false });

updatePositions();

function render() {
  requestAnimationFrame(render);
  renderer.render(scene, camera);
}
render();

new ResizeObserver(() => {
  camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
  camera.updateProjectionMatrix();
  renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
}).observe(wrapper);

wrapper._updateTexture = updateTextureFromUrl;

let lastImageUrl = props.imageUrl;
let lastValue = JSON.stringify(props.value);
setInterval(() => {
  if (props.imageUrl !== lastImageUrl) {
    lastImageUrl = props.imageUrl;
    updateTextureFromUrl(props.imageUrl);
  }
  const currentValue = JSON.stringify(props.value);
  if (currentValue !== lastValue) {
    lastValue = currentValue;
    if (props.value && typeof props.value === 'object') {
      rotateDeg = props.value.rotate_deg ?? rotateDeg;
      zoom = props.value.zoom ?? zoom;
      verticalTilt = props.value.vertical_tilt ?? verticalTilt;
      updatePositions();
    }
  }
}, 100);
};

initScene();
})();
"""


def create_camera_3d_component(value=None, imageUrl=None, **kwargs):
    """Create a 3D camera control component using gr.HTML."""
    if value is None:
        value = {"rotate_deg": 0, "zoom": 5.0, "vertical_tilt": 0}

    return gr.HTML(
        value=value,
        html_template=CAMERA_3D_HTML_TEMPLATE,
        js_on_load=CAMERA_3D_JS,
        imageUrl=imageUrl,
        **kwargs,
    )


# --- UI ---
css = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .progress-text { color: white !important; }
#camera-3d-control { min-height: 400px; }
#examples { max-width: 1100px; margin: 0 auto; }
.fillable { max-width: 1250px !important }
"""

def reset_all() -> list:
    """Reset all camera control knobs and flags to their default values."""
    return [0, 5.0, 0, True]  # rotate_deg, zoom, vertical_tilt, is_reset


def end_reset() -> bool:
    """Mark the end of a reset cycle."""
    return False


def update_dimensions_on_upload(image: Optional[Image.Image]) -> Tuple[int, int]:
    """Compute recommended (width, height) for the output resolution."""
    if image is None:
        return 1024, 1024

    original_width, original_height = image.size

    # Scale the longer side to 1024 and preserve the aspect ratio.
    if original_width > original_height:
        new_width = 1024
        aspect_ratio = original_height / original_width
        new_height = int(new_width * aspect_ratio)
    else:
        new_height = 1024
        aspect_ratio = original_width / original_height
        new_width = int(new_height * aspect_ratio)

    # Snap both dimensions down to multiples of 8.
    new_width = (new_width // 8) * 8
    new_height = (new_height // 8) * 8

    return new_width, new_height

with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
    gr.Markdown("""
    ## 🎬 Fibo Edit — Camera Angle Control

    Fibo Edit with Multi-Angle LoRA for precise camera control ✨
    Control rotation, tilt, and zoom to generate images from any angle 🎥
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(label="Input Image", type="pil", height=280)
            prev_output = gr.Image(value=None, visible=False)
            is_reset = gr.Checkbox(value=False, visible=False)
            # Hidden state to pass the processed image between steps
            processed_image = gr.State(None)

            gr.Markdown("### 🎮 3D Camera Control")

            camera_3d = create_camera_3d_component(
                value={"rotate_deg": 0, "zoom": 5.0, "vertical_tilt": 0},
                elem_id="camera-3d-control",
            )

            with gr.Row():
                reset_btn = gr.Button("🔄 Reset", size="sm")
                run_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            result = gr.Image(label="Output Image", interactive=False, height=350)

            gr.Markdown("### 🎚️ Slider Controls")

            rotate_deg = gr.Slider(
                label="Horizontal Rotation (°)",
                minimum=-180,
                maximum=180,
                step=45,
                value=0,
                info="-180/180: back, -90: left, 0: front, 90: right",
            )
            zoom = gr.Slider(
                label="Zoom Level",
                minimum=0,
                maximum=10,
                step=1,
                value=5.0,
                info="0-3.33: wide, 3.33-6.66: medium, 6.66-10: close-up",
            )
            vertical_tilt = gr.Slider(
                label="Vertical Tilt",
                minimum=-1,
                maximum=1,
                step=0.5,
                value=0,
                info="-1: low-angle, 0: eye-level, 1: high-angle",
            )

            prompt_preview = gr.Textbox(label="Generated Prompt", interactive=False)

            with gr.Accordion("📋 Structured Caption (BRIA API)", open=False):
                structured_json = gr.JSON(label="JSON Response", container=False)

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=DEFAULT_SEED,
                )
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULT_GUIDANCE_SCALE,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=DEFAULT_NUM_INFERENCE_STEPS,
                )
                height = gr.Slider(
                    label="Height", minimum=256, maximum=2048, step=8, value=1024
                )
                width = gr.Slider(
                    label="Width", minimum=256, maximum=2048, step=8, value=1024
                )

    # --- Helper Functions ---
    def update_prompt_from_sliders(rotate, zoom_val, tilt):
        prompt, _ = build_camera_prompt(rotate, zoom_val, tilt)
        return prompt

    def sync_3d_to_sliders(camera_value):
        if camera_value and isinstance(camera_value, dict):
            rot = camera_value.get("rotate_deg", 0)
            zoom_val = camera_value.get("zoom", 5.0)
            tilt = camera_value.get("vertical_tilt", 0)
            prompt, _ = build_camera_prompt(rot, zoom_val, tilt)
            return rot, zoom_val, tilt, prompt
        return gr.update(), gr.update(), gr.update(), gr.update()

    def sync_sliders_to_3d(rotate, zoom_val, tilt):
        return {"rotate_deg": rotate, "zoom": zoom_val, "vertical_tilt": tilt}

    def update_3d_image(img):
        if img is None:
            return gr.update(imageUrl=None)
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        data_url = f"data:image/png;base64,{img_str}"
        return gr.update(imageUrl=data_url)

    # --- Event Handlers ---

    # Slider -> Prompt preview
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.change(
            fn=update_prompt_from_sliders,
            inputs=[rotate_deg, zoom, vertical_tilt],
            outputs=[prompt_preview],
        )

    # 3D control -> Sliders + Prompt (no auto-inference)
    camera_3d.change(
        fn=sync_3d_to_sliders,
        inputs=[camera_3d],
        outputs=[rotate_deg, zoom, vertical_tilt, prompt_preview],
    )

    # Sliders -> 3D control (no auto-inference)
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.release(
            fn=sync_sliders_to_3d,
            inputs=[rotate_deg, zoom, vertical_tilt],
            outputs=[camera_3d],
        )

    # Reset
    reset_btn.click(
        fn=reset_all,
        inputs=None,
        outputs=[rotate_deg, zoom, vertical_tilt, is_reset],
        queue=False,
    ).then(fn=end_reset, inputs=None, outputs=[is_reset], queue=False).then(
        fn=sync_sliders_to_3d,
        inputs=[rotate_deg, zoom, vertical_tilt],
        outputs=[camera_3d],
    )

    # Generate button - Two-stage process
    # Stage 1: Fetch structured caption from BRIA API and display it immediately
    run_event = run_btn.click(
        fn=fetch_structured_caption,
        inputs=[
            image,
            rotate_deg,
            zoom,
            vertical_tilt,
            seed,
            randomize_seed,
            prev_output,
        ],
        outputs=[seed, prompt_preview, structured_json, processed_image],
    ).then(
        # Stage 2: Generate image with Fibo Edit pipeline
        fn=generate_image_from_caption,
        inputs=[
            processed_image,
            structured_json,
            seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result],
    )

    # Image upload
    image.upload(
        fn=update_dimensions_on_upload, inputs=[image], outputs=[width, height]
    ).then(
        fn=reset_all,
        inputs=None,
        outputs=[rotate_deg, zoom, vertical_tilt, is_reset],
        queue=False,
    ).then(fn=end_reset, inputs=None, outputs=[is_reset], queue=False).then(
        fn=update_3d_image, inputs=[image], outputs=[camera_3d]
    )

    image.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])

    run_event.then(lambda img, *_: img, inputs=[result], outputs=[prev_output])

    # Examples - Commenting out for now since we need actual example images
    # Note: With the two-stage inference process, examples would need custom handling
    # to properly chain fetch_structured_caption -> generate_image_from_caption

    # Sync 3D component when sliders change (covers example loading)
    def sync_3d_on_slider_change(img, rot, zoom_val, tilt):
        camera_value = {"rotate_deg": rot, "zoom": zoom_val, "vertical_tilt": tilt}
        if img is not None:
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            data_url = f"data:image/png;base64,{img_str}"
            return gr.update(value=camera_value, imageUrl=data_url)
        return gr.update(value=camera_value)

    # When any slider value changes (including from examples), sync the 3D component
    for slider in [rotate_deg, zoom, vertical_tilt]:
        slider.change(
            fn=sync_3d_on_slider_change,
            inputs=[image, rotate_deg, zoom, vertical_tilt],
            outputs=[camera_3d],
        )

    # API endpoints for the two-stage inference process
    gr.api(fetch_structured_caption, api_name="fetch_caption")
    gr.api(generate_image_from_caption, api_name="generate_image")

if __name__ == "__main__":
    head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'

    if RUN_LOCAL:
        # Local development configuration
        demo.launch(
            mcp_server=True,
            head=head,
            footer_links=["api", "gradio", "settings"],
            server_name="0.0.0.0",
            server_port=8081,
        )
    else:
        # HuggingFace Spaces standard configuration
        demo.launch(head=head)
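For reference, the camera-control bucketing that the widget's prompt preview and the slider `info` strings describe can be mirrored in plain Python. This is an illustrative sketch of the same thresholds only; the Space's actual `build_camera_prompt` helper (defined elsewhere in `app.py`) may differ:

```python
def camera_phrases(rotate_deg: float, zoom: float, tilt: float) -> list:
    """Bucket the three camera controls into human-readable phrases.

    Thresholds mirror the JS preview: zoom >= 6.66 is a close-up,
    zoom >= 3.33 a medium shot, otherwise wide angle; tilt buckets
    at +/-0.33 and 0.66.
    """
    parts = []
    if rotate_deg != 0:
        direction = "right" if rotate_deg > 0 else "left"
        parts.append(f"Rotate {abs(rotate_deg)}° {direction}")
    if zoom >= 6.66:
        parts.append("Close-up")
    elif zoom >= 3.33:
        parts.append("Medium shot")
    else:
        parts.append("Wide angle")
    if tilt >= 0.66:
        parts.append("High angle")
    elif tilt >= 0.33:
        parts.append("Elevated")
    elif tilt <= -0.33:
        parts.append("Low angle")
    else:
        parts.append("Eye level")
    return parts
```

Because the sliders snap (rotation in 45° steps, zoom in integers, tilt in halves), only a small, predictable set of phrase combinations can ever appear in the preview.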
fibo_edit_pipeline.py
ADDED
@@ -0,0 +1,953 @@
# Copyright (c) Bria.ai. All rights reserved.
#
# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
#
# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
# indicate if changes were made, and do not use the material for commercial purposes.
#
# See the license for further details.

from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
import PIL
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
from diffusers.pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Example:
        ```python
        import torch
        from diffusers import BriaFiboPipeline
        from diffusers.modular_pipelines import ModularPipeline

        torch.set_grad_enabled(False)
        vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

        pipe = BriaFiboPipeline.from_pretrained(
            "briaai/FIBO",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        pipe.enable_model_cpu_offload()

        with torch.inference_mode():
            # 1. Create a prompt to generate an initial image
            output = vlm_pipe(prompt="a beautiful dog")
            json_prompt_generate = output.values["json_prompt"]

            # Generate the image from the structured json prompt
            results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
            results_generate.images[0].save("image_generate.png")
        ```
"""

PREFERRED_RESOLUTION = {
    256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)],
    512 * 512: [
        (416, 624),
        (432, 592),
        (464, 560),
        (512, 512),
        (544, 480),
        (576, 448),
        (592, 432),
        (608, 416),
        (624, 416),
        (640, 400),
        (672, 384),
        (704, 368),
    ],
    1024 * 1024: [
        (832, 1248),
        (880, 1184),
        (912, 1136),
        (1024, 1024),
        (1136, 912),
        (1184, 880),
        (1216, 848),
        (1248, 832),
        (1264, 816),
        (1296, 800),
        (1360, 768),
    ],
}

class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
| 109 |
+
r"""
|
| 110 |
+
Args:
|
| 111 |
+
transformer (`BriaFiboTransformer2DModel`):
|
| 112 |
+
The transformer model for 2D diffusion modeling.
|
| 113 |
+
scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
|
| 114 |
+
Scheduler to be used with `transformer` to denoise the encoded latents.
|
| 115 |
+
vae (`AutoencoderKLWan`):
|
| 116 |
+
Variational Auto-Encoder for encoding and decoding images to and from latent representations.
|
| 117 |
+
text_encoder (`SmolLM3ForCausalLM`):
|
| 118 |
+
Text encoder for processing input prompts.
|
| 119 |
+
tokenizer (`AutoTokenizer`):
|
| 120 |
+
Tokenizer used for processing the input text prompts for the text_encoder.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
| 124 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
transformer: BriaFiboTransformer2DModel,
|
| 129 |
+
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
| 130 |
+
vae: AutoencoderKLWan,
|
| 131 |
+
text_encoder: SmolLM3ForCausalLM,
|
| 132 |
+
tokenizer: AutoTokenizer,
|
| 133 |
+
):
|
| 134 |
+
self.register_modules(
|
| 135 |
+
vae=vae,
|
| 136 |
+
text_encoder=text_encoder,
|
| 137 |
+
tokenizer=tokenizer,
|
| 138 |
+
transformer=transformer,
|
| 139 |
+
scheduler=scheduler,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
self.vae_scale_factor = 16
|
| 143 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2)
|
| 144 |
+
self.default_sample_size = 32 # 64
|
| 145 |
+
|
    def get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 2048,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if not prompt:
            raise ValueError("`prompt` must be a non-empty string or list of strings.")

        batch_size = len(prompt)
        bot_token_id = 128000

        text_encoder_device = device if device is not None else torch.device("cpu")
        if not isinstance(text_encoder_device, torch.device):
            text_encoder_device = torch.device(text_encoder_device)

        if all(p == "" for p in prompt):
            input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
            attention_mask = torch.ones_like(input_ids)
        else:
            tokenized = self.tokenizer(
                prompt,
                padding="longest",
                max_length=max_sequence_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            input_ids = tokenized.input_ids.to(text_encoder_device)
            attention_mask = tokenized.attention_mask.to(text_encoder_device)

            if any(p == "" for p in prompt):
                empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
                input_ids[empty_rows] = bot_token_id
                attention_mask[empty_rows] = 1

        encoder_outputs = self.text_encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_outputs.hidden_states

        prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
        prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hidden_states = tuple(
            layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
        )
        attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)

        return prompt_embeds, hidden_states, attention_mask
    @staticmethod
    def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
        # Pad embeddings to `max_tokens` while preserving the mask of real tokens.
        batch_size, seq_len, dim = prompt_embeds.shape

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
        else:
            attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)

        if max_tokens < seq_len:
            raise ValueError("`max_tokens` must be greater than or equal to the current sequence length.")

        if max_tokens > seq_len:
            pad_length = max_tokens - seq_len
            padding = torch.zeros((batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
            prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)

            mask_padding = torch.zeros((batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
            attention_mask = torch.cat([attention_mask, mask_padding], dim=1)

        return prompt_embeds, attention_mask
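What `pad_embedding` does, sketched with plain Python lists in place of tensors (illustrative only; the real method zero-pads torch tensors and their attention masks the same way):

```python
def pad_embedding_sketch(embeds, max_tokens):
    """Zero-pad a [seq_len][dim] embedding to max_tokens rows; mask marks real tokens."""
    seq_len = len(embeds)
    dim = len(embeds[0])
    if max_tokens < seq_len:
        raise ValueError("max_tokens must be >= current sequence length")
    padded = embeds + [[0.0] * dim for _ in range(max_tokens - seq_len)]
    mask = [1.0] * seq_len + [0.0] * (max_tokens - seq_len)
    return padded, mask

padded, mask = pad_embedding_sketch([[1.0, 2.0], [3.0, 4.0]], 4)
# Two real tokens followed by two zero-padded rows:
assert mask == [1.0, 1.0, 0.0, 0.0]
```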
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 3000,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Args:
            prompt (`str` or `List[str]`, *optional*):
                Prompt to be encoded.
            device (`torch.device`):
                Torch device.
            num_images_per_prompt (`int`):
                Number of images that should be generated per prompt.
            guidance_scale (`float`):
                Guidance scale for classifier-free guidance.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
                argument.
        """
        device = device or self._execution_device

        # Set the LoRA scale so that the monkey-patched LoRA
        # functions of the text encoder can correctly access it.
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        prompt_attention_mask = None
        negative_prompt_attention_mask = None
        if prompt_embeds is None:
            prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
            prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]

        if guidance_scale > 1:
            if isinstance(negative_prompt, list) and negative_prompt[0] is None:
                negative_prompt = ""
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
            negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers.
                unscale_lora_layers(self.text_encoder, lora_scale)

        # Pad positive and negative embeddings to the same (longest) length.
        if prompt_attention_mask is not None:
            prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)

        if negative_prompt_embeds is not None:
            if negative_prompt_attention_mask is not None:
                negative_prompt_attention_mask = negative_prompt_attention_mask.to(
                    device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
                )
            max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])

            prompt_embeds, prompt_attention_mask = self.pad_embedding(
                prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
            )
            prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]

            negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
                negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
            )
            negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
        else:
            max_tokens = prompt_embeds.shape[1]
            prompt_embeds, prompt_attention_mask = self.pad_embedding(
                prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
            )
            negative_prompt_layers = None

        dtype = self.text_encoder.dtype
        text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_layers,
            negative_prompt_layers,
        )
    @property
    def guidance_scale(self):
        return self._guidance_scale

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487. `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt
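The classifier-free guidance weight referenced above combines the unconditional and text-conditioned predictions as `uncond + w * (text - uncond)`, matching the guidance step later in the denoising loop. A scalar sketch (the pipeline applies the same formula to noise-prediction tensors):

```python
def cfg_combine(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional prediction
    toward (and past) the text-conditioned one by the guidance weight."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

assert cfg_combine(0.0, 1.0, 1.0) == 1.0  # w = 1: plain conditional prediction
assert cfg_combine(0.0, 1.0, 5.0) == 5.0  # w = 5: extrapolates past the conditional prediction
```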
    # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
        return latents

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels)
        latents = latents.permute(0, 3, 1, 2)

        return latents

    @staticmethod
    def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
        latents = latents.permute(0, 2, 3, 1)
        latents = latents.reshape(batch_size, height * width, num_channels_latents)
        return latents

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents
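The "no patch" pack/unpack pair simply flattens the spatial latent grid into a token sequence and back. A numpy sketch of the same reshape/transpose semantics (the real methods use torch, and `_unpack_latents_no_patch` receives pixel dimensions that it first divides by the VAE scale factor; this sketch takes latent dimensions directly):

```python
import numpy as np

def pack_no_patch(latents):
    # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
    b, c, h, w = latents.shape
    return latents.transpose(0, 2, 3, 1).reshape(b, h * w, c)

def unpack_no_patch(latents, h, w):
    # (B, H*W, C) -> (B, C, H, W): restore the spatial grid.
    b, _, c = latents.shape
    return latents.reshape(b, h, w, c).transpose(0, 3, 1, 2)

x = np.arange(2 * 4 * 3 * 5, dtype=np.float32).reshape(2, 4, 3, 5)
assert pack_no_patch(x).shape == (2, 15, 4)
assert np.array_equal(unpack_no_patch(pack_no_patch(x), 3, 5), x)  # lossless round trip
```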
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
        do_patching=False,
    ):
        height = int(height) // self.vae_scale_factor
        width = int(width) // self.vae_scale_factor

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if do_patching:
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
        else:
            latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids
    @staticmethod
    def _prepare_attention_mask(attention_mask):
        attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)

        # Convert to 0 = keep, -inf = ignore: -inf nulls the softmax score of ignored tokens.
        attention_matrix = torch.where(attention_matrix == 1, 0.0, -torch.inf)
        return attention_matrix
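A numpy stand-in for the construction above: the outer product of a 1-D keep/ignore mask with itself gives a `(seq, seq)` matrix, and kept pairs become 0 while any pair involving a padded token becomes `-inf`:

```python
import numpy as np

def prepare_attention_mask_sketch(mask):
    # (B, S) outer product -> (B, S, S); 1*1 pairs survive, anything else is masked.
    matrix = np.einsum("bi,bj->bij", mask, mask)
    return np.where(matrix == 1, 0.0, -np.inf)

mask = np.array([[1.0, 1.0, 0.0]])  # batch of 1; the last token is padding
out = prepare_attention_mask_sketch(mask)
assert out.shape == (1, 3, 3)
assert out[0, 0, 0] == 0.0          # real-token pair: kept
assert np.isneginf(out[0, 0, 2])    # any pair with padding: masked
```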
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Optional[Union[PIL.Image.Image, torch.FloatTensor]] = None,
        num_inference_steps: int = 30,
        timesteps: List[int] = None,
        guidance_scale: float = 5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 3000,
        do_patching=False,
        _auto_resize: bool = True,
        base_resolution: int = 1024,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*):
                The image to guide the image generation. If not defined, the pipeline will generate an image from
                scratch.
            num_inference_steps (`int`, *optional*, defaults to 30):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance is enabled by setting
                `guidance_scale > 1`. A higher guidance scale encourages generating images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to 3000):
                Maximum sequence length to use with the `prompt`.
            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.

        Examples:

        Returns:
            [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        if image is not None and _auto_resize:
            image_height, image_width = self.image_processor.get_default_height_width(image)
            image_width, image_height = min(
                PREFERRED_RESOLUTION[base_resolution * base_resolution],
                key=lambda size: abs(size[0] / size[1] - image_width / image_height),
            )
            width, height = image_width, image_height

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None

        (
            prompt_embeds,
            negative_prompt_embeds,
            text_ids,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_layers,
            negative_prompt_layers,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            max_sequence_length=max_sequence_length,
            num_images_per_prompt=num_images_per_prompt,
            lora_scale=lora_scale,
        )
        prompt_batch_size = prompt_embeds.shape[0]

        if guidance_scale > 1:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_layers = [
                torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
            ]
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
            self.transformer.single_transformer_blocks
        )
        if len(prompt_layers) >= total_num_layers_transformer:
            # remove first layers
            prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
        else:
            # duplicate last layer
            prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))

        # Preprocess image
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
            image = self.image_processor.resize(image, height, width)
            image = self.image_processor.preprocess(image, height, width)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        if do_patching:
            num_channels_latents = int(num_channels_latents / 4)

        latents, latent_image_ids = self.prepare_latents(
            prompt_batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            do_patching,
        )

        if image is not None:
            image_latents, image_ids = self.prepare_image_latents(
                image=image,
                batch_size=batch_size * num_images_per_prompt,
                num_channels_latents=num_channels_latents,
                height=height,
                width=width,
                dtype=prompt_embeds.dtype,
                device=device,
                generator=generator,
            )
            latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0)  # dim 0 is the sequence dimension
        else:
            image_latents = None

        latent_attention_mask = torch.ones(
            [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
        )
        if guidance_scale > 1:
            latent_attention_mask = latent_attention_mask.repeat(2, 1)

        if image_latents is None:
            attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
        else:
            image_latent_attention_mask = torch.ones(
                [image_latents.shape[0], image_latents.shape[1]],
                dtype=image_latents.dtype,
                device=image_latents.device,
            )
            if guidance_scale > 1:
                image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1)
            attention_mask = torch.cat(
                [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1
            )

        attention_mask = self._prepare_attention_mask(attention_mask)  # batch, seq => batch, seq, seq
        attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype)  # for head broadcasting

        if self._joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}
        self._joint_attention_kwargs["attention_mask"] = attention_mask

        # Adapt the scheduler to dynamic shifting (resolution dependent).
        if do_patching:
            seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
        else:
            seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)

        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

        mu = calculate_shift(
            seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )

        # Init sigmas and timesteps according to the shift size.
        # This changes the scheduler in-place according to the dynamic scheduling.
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps=num_inference_steps,
            device=device,
            timesteps=None,
            sigmas=sigmas,
            mu=mu,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Support older diffusers versions.
        if len(latent_image_ids.shape) == 3:
            latent_image_ids = latent_image_ids[0]

        if len(text_ids.shape) == 3:
            text_ids = text_ids[0]

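`calculate_shift` (a helper defined earlier in this file, following diffusers' Flux pipeline) linearly interpolates the timestep-shift parameter `mu` between `base_shift` and `max_shift` according to the latent sequence length. A sketch under that assumption; the default values shown are the usual Flux scheduler-config values, not taken from this file:

```python
def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096,
                           base_shift=0.5, max_shift=1.15):
    """Linear interpolation of mu: longer latent sequences get a larger shift."""
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

assert abs(calculate_shift_sketch(256) - 0.5) < 1e-9    # shortest sequences use base_shift
assert abs(calculate_shift_sketch(4096) - 1.15) < 1e-9  # longest sequences use max_shift
```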
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = latents

                if image_latents is not None:
                    latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

                # expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0]).to(
                    device=latent_model_input.device, dtype=latent_model_input.dtype
                )

                # This predicts "v" (flow matching) or eps (diffusion).
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    text_encoder_layers=prompt_layers,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                )[0]

                # perform guidance
                if guidance_scale > 1:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
|
| 817 |
+
image = latents
|
| 818 |
+
|
| 819 |
+
else:
|
| 820 |
+
if do_patching:
|
| 821 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 822 |
+
else:
|
| 823 |
+
latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
|
| 824 |
+
|
| 825 |
+
latents = latents.unsqueeze(dim=2)
|
| 826 |
+
latents_device = latents[0].device
|
| 827 |
+
latents_dtype = latents[0].dtype
|
| 828 |
+
latents_mean = (
|
| 829 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 830 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 831 |
+
.to(latents_device, latents_dtype)
|
| 832 |
+
)
|
| 833 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 834 |
+
latents_device, latents_dtype
|
| 835 |
+
)
|
| 836 |
+
latents_scaled = [latent / latents_std + latents_mean for latent in latents]
|
| 837 |
+
latents_scaled = torch.cat(latents_scaled, dim=0)
|
| 838 |
+
image = []
|
| 839 |
+
for scaled_latent in latents_scaled:
|
| 840 |
+
curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
|
| 841 |
+
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
|
| 842 |
+
image.append(curr_image)
|
| 843 |
+
if len(image) == 1:
|
| 844 |
+
image = image[0]
|
| 845 |
+
else:
|
| 846 |
+
image = np.stack(image, axis=0)
|
| 847 |
+
|
| 848 |
+
# Offload all models
|
| 849 |
+
self.maybe_free_model_hooks()
|
| 850 |
+
|
| 851 |
+
if not return_dict:
|
| 852 |
+
return (image,)
|
| 853 |
+
|
| 854 |
+
return BriaFiboPipelineOutput(images=image)
|
| 855 |
+
|
| 856 |
+
def prepare_image_latents(
|
| 857 |
+
self,
|
| 858 |
+
image: torch.Tensor,
|
| 859 |
+
batch_size: int,
|
| 860 |
+
num_channels_latents: int,
|
| 861 |
+
height: int,
|
| 862 |
+
width: int,
|
| 863 |
+
dtype: torch.dtype,
|
| 864 |
+
device: torch.device,
|
| 865 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 866 |
+
):
|
| 867 |
+
image = image.to(device=device, dtype=dtype)
|
| 868 |
+
|
| 869 |
+
height = int(height) // self.vae_scale_factor
|
| 870 |
+
width = int(width) // self.vae_scale_factor
|
| 871 |
+
|
| 872 |
+
# scaling
|
| 873 |
+
latents_mean = (
|
| 874 |
+
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
| 875 |
+
)
|
| 876 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 877 |
+
device, dtype
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean
|
| 881 |
+
latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw]
|
| 882 |
+
image_latents_cthw = torch.concat(latents_scaled, dim=0)
|
| 883 |
+
image_latents_bchw = image_latents_cthw[:, :, 0, :, :]
|
| 884 |
+
|
| 885 |
+
image_latent_height, image_latent_width = image_latents_bchw.shape[2:]
|
| 886 |
+
image_latents_bsd = self._pack_latents_no_patch(
|
| 887 |
+
latents=image_latents_bchw,
|
| 888 |
+
batch_size=batch_size,
|
| 889 |
+
num_channels_latents=num_channels_latents,
|
| 890 |
+
height=image_latent_height,
|
| 891 |
+
width=image_latent_width,
|
| 892 |
+
)
|
| 893 |
+
# breakpoint()
|
| 894 |
+
image_ids = self._prepare_latent_image_ids(
|
| 895 |
+
batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype
|
| 896 |
+
)
|
| 897 |
+
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
|
| 898 |
+
image_ids[..., 0] = 1
|
| 899 |
+
return image_latents_bsd, image_ids
|
| 900 |
+
|
| 901 |
+
def check_inputs(
|
| 902 |
+
self,
|
| 903 |
+
prompt,
|
| 904 |
+
negative_prompt=None,
|
| 905 |
+
prompt_embeds=None,
|
| 906 |
+
negative_prompt_embeds=None,
|
| 907 |
+
callback_on_step_end_tensor_inputs=None,
|
| 908 |
+
max_sequence_length=None,
|
| 909 |
+
):
|
| 910 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 911 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 912 |
+
):
|
| 913 |
+
raise ValueError(
|
| 914 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
if prompt is not None and prompt_embeds is not None:
|
| 918 |
+
raise ValueError(
|
| 919 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 920 |
+
" only forward one of the two."
|
| 921 |
+
)
|
| 922 |
+
elif prompt is None and prompt_embeds is None:
|
| 923 |
+
raise ValueError(
|
| 924 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 925 |
+
)
|
| 926 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 927 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 928 |
+
|
| 929 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 930 |
+
raise ValueError(
|
| 931 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 932 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 936 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 937 |
+
raise ValueError(
|
| 938 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 939 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 940 |
+
f" {negative_prompt_embeds.shape}."
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
if max_sequence_length is not None and max_sequence_length > 3000:
|
| 944 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
|
| 945 |
+
|
| 946 |
+
def create_attention_matrix(self, attention_mask):
|
| 947 |
+
attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
|
| 948 |
+
|
| 949 |
+
# convert to 0 - keep, -inf ignore
|
| 950 |
+
attention_matrix = torch.where(
|
| 951 |
+
attention_matrix == 1, 0.0, -torch.inf
|
| 952 |
+
) # Apply -inf to ignored tokens for nulling softmax score
|
| 953 |
+
return attention_matrix
|
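The mask logic in `create_attention_matrix` above is just an outer product of a per-token keep-mask with itself, followed by a 1 → 0.0 / 0 → -inf mapping. A minimal pure-Python sketch of the same idea (the function name here is illustrative, not part of the pipeline):

```python
import math


def attention_matrix_from_mask(mask):
    """Given a 0/1 keep-mask per token, build a pairwise matrix whose
    entries are 0.0 when both tokens are kept and -inf otherwise,
    mirroring the einsum("bi,bj->bij") outer product plus torch.where."""
    n = len(mask)
    return [
        [0.0 if mask[i] * mask[j] == 1 else -math.inf for j in range(n)]
        for i in range(n)
    ]


m = attention_matrix_from_mask([1, 1, 0])
print(m[0])  # [0.0, 0.0, -inf] — a kept token attends only to kept tokens
print(m[2])  # [-inf, -inf, -inf] — a padded token is fully masked out
```

Adding -inf before the softmax drives the corresponding attention weights to zero, which is why the matrix uses 0.0 (no-op) for kept pairs rather than 1.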
requirements.txt
ADDED
@@ -0,0 +1,133 @@
accelerate==1.12.0
aiofiles==24.1.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
asttokens==3.0.1
attrs==25.4.0
boto3==1.42.28
botocore==1.42.28
brotli==1.2.0
certifi==2026.1.4
cffi==2.0.0 ; platform_python_implementation != 'PyPy'
charset-normalizer==3.4.4
click==8.3.1
colorama==0.4.6 ; sys_platform == 'win32'
cryptography==46.0.3
cuda-bindings==12.9.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
cuda-pathfinder==1.3.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
decorator==5.2.1
diffusers @ git+https://github.com/huggingface/diffusers@956bdcc3ea4897eaeb6c828b8433bdcae71e9f0f
einops==0.8.2
exceptiongroup==1.3.1 ; python_full_version < '3.11'
executing==2.2.1
fal-client==0.12.0
fastapi==0.128.0
ffmpy==1.0.0
filelock==3.20.3
fsspec==2026.1.0
gradio==6.4.0
gradio-client==2.0.3
groovy==0.1.2
h11==0.16.0
hf-xet==1.2.0 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
httpcore==1.0.9
httpx==0.28.1
httpx-sse==0.4.3
huggingface-hub==1.3.4
idna==3.11
importlib-metadata==8.7.1
ipython==8.38.0 ; python_full_version < '3.11'
ipython==9.9.0 ; python_full_version >= '3.11'
ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
jedi==0.19.2
jinja2==3.1.6
jmespath==1.0.1
jsonschema==4.26.0
jsonschema-specifications==2025.9.1
markdown-it-py==4.0.0
markupsafe==3.0.3
matplotlib-inline==0.2.1
mcp==1.26.0
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.2
networkx==3.4.2 ; python_full_version < '3.11'
networkx==3.6.1 ; python_full_version >= '3.11'
numpy==2.2.6 ; python_full_version < '3.11'
numpy==2.4.1 ; python_full_version >= '3.11'
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvshmem-cu12==3.4.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
orjson==3.11.5
packaging==26.0
pandas==2.3.3
parso==0.8.5
peft==0.18.1
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pillow==12.1.0
prompt-toolkit==3.0.52
psutil==5.9.8
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
pure-eval==0.2.3
pycparser==3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
pydantic==2.12.5
pydantic-core==2.41.5
pydantic-settings==2.12.0
pydub==0.25.1
pygments==2.19.2
pyjwt==2.10.1
python-dateutil==2.9.0.post0
python-dotenv==1.2.1
python-multipart==0.0.22
pytz==2025.2
pywin32==311 ; sys_platform == 'win32'
pyyaml==6.0.3
referencing==0.37.0
regex==2026.1.15
requests==2.32.5
rich==14.3.1
rpds-py==0.30.0
s3transfer==0.16.0
safehttpx==0.1.7
safetensors==0.7.0
semantic-version==2.10.0
setuptools==80.10.2 ; python_full_version >= '3.12'
shellingham==1.5.4
six==1.17.0
spaces==0.47.0
sse-starlette==3.2.0
stack-data==0.6.3
starlette==0.50.0
sympy==1.14.0
tokenizers==0.22.2
tomlkit==0.13.3
torch==2.10.0
torchvision==0.25.0
tqdm==4.67.1
traitlets==5.14.3
transformers==5.0.0
triton==3.6.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typer==0.21.1
typer-slim==0.21.1
typing-extensions==4.15.0
typing-inspection==0.4.2
tzdata==2025.3
ujson==5.11.0
urllib3==2.6.3
uvicorn==0.40.0
wcwidth==0.2.14
websockets==16.0
zipp==3.23.0
utils.py
ADDED
@@ -0,0 +1,113 @@
"""Camera angle data structures for Fibo Edit."""

from dataclasses import dataclass
from enum import Enum


class View(Enum):
    """Camera view angles"""

    BACK_VIEW = "back view"
    BACK_LEFT_QUARTER = "back-left quarter view"
    BACK_RIGHT_QUARTER = "back-right quarter view"
    FRONT_VIEW = "front view"
    FRONT_LEFT_QUARTER = "front-left quarter view"
    FRONT_RIGHT_QUARTER = "front-right quarter view"
    LEFT_SIDE = "left side view"
    RIGHT_SIDE = "right side view"


class Shot(Enum):
    """
    Camera shot angles (measured from horizontal/eye level as 0 degrees)

    - ELEVATED: 45-60 degrees above the subject (moderately elevated)
    - EYE_LEVEL: 0 degrees (horizontal with the subject)
    - HIGH_ANGLE: 60-90 degrees above the subject (steep overhead, bird's eye)
    - LOW_ANGLE: below eye level (looking up at the subject)
    """

    ELEVATED = "elevated shot"
    EYE_LEVEL = "eye-level shot"
    HIGH_ANGLE = "high-angle shot"
    LOW_ANGLE = "low-angle shot"


class Zoom(Enum):
    """Camera zoom levels"""

    CLOSE_UP = "close-up"
    MEDIUM = "medium shot"
    WIDE = "wide shot"


@dataclass
class AngleInstruction:
    view: View
    shot: Shot
    zoom: Zoom

    def __str__(self):
        return f"<sks> {self.view.value} {self.shot.value} {self.zoom.value}"

    @classmethod
    def from_camera_params(cls, rotation: float, tilt: float, zoom: float) -> "AngleInstruction":
        """
        Create an AngleInstruction from camera parameters.

        Args:
            rotation: Horizontal rotation in degrees (-180 to 180).
                -180/180: back view, -90: left view, 0: front view, 90: right view
            tilt: Vertical tilt (-1 to 1).
                -1 to -0.33: low-angle shot
                -0.33 to 0.33: eye-level shot
                0.33 to 0.66: elevated shot
                0.66 to 1: high-angle shot
            zoom: Zoom level (0 to 10).
                0-3.33: wide shot
                3.33-6.66: medium shot
                6.66-10: close-up

        Returns:
            AngleInstruction instance
        """
        # Normalize rotation to the -180 to 180 range
        rotation = rotation % 360
        if rotation > 180:
            rotation -= 360

        # Determine the view based on rotation
        if -157.5 <= rotation < -112.5:
            view = View.BACK_LEFT_QUARTER
        elif -112.5 <= rotation < -67.5:
            view = View.LEFT_SIDE
        elif -67.5 <= rotation < -22.5:
            view = View.FRONT_LEFT_QUARTER
        elif -22.5 <= rotation < 22.5:
            view = View.FRONT_VIEW
        elif 22.5 <= rotation < 67.5:
            view = View.FRONT_RIGHT_QUARTER
        elif 67.5 <= rotation < 112.5:
            view = View.RIGHT_SIDE
        elif 112.5 <= rotation < 157.5:
            view = View.BACK_RIGHT_QUARTER
        else:  # 157.5 to 180 or -180 to -157.5
            view = View.BACK_VIEW

        # Map tilt to Shot
        if tilt < -0.33:
            shot = Shot.LOW_ANGLE
        elif tilt < 0.33:
            shot = Shot.EYE_LEVEL
        elif tilt < 0.66:
            shot = Shot.ELEVATED
        else:
            shot = Shot.HIGH_ANGLE

        # Map zoom to Zoom
        if zoom < 3.33:
            zoom_level = Zoom.WIDE
        elif zoom < 6.66:
            zoom_level = Zoom.MEDIUM
        else:
            zoom_level = Zoom.CLOSE_UP

        return cls(view=view, shot=shot, zoom=zoom_level)
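As a quick check of the thresholds in `from_camera_params`, here is a condensed standalone sketch that applies the same bins and produces the same `<sks> …` prompt fragment (the `bin_camera` helper is illustrative, not part of utils.py):

```python
def bin_camera(rotation: float, tilt: float, zoom: float) -> str:
    """Condensed mirror of AngleInstruction.from_camera_params:
    bins rotation/tilt/zoom into the same prompt string."""
    # Normalize rotation to the -180 to 180 range
    rotation = rotation % 360
    if rotation > 180:
        rotation -= 360

    # Eight 45-degree view sectors, centered on front (0), sides (+/-90), back (180)
    if -157.5 <= rotation < -112.5:
        view = "back-left quarter view"
    elif -112.5 <= rotation < -67.5:
        view = "left side view"
    elif -67.5 <= rotation < -22.5:
        view = "front-left quarter view"
    elif -22.5 <= rotation < 22.5:
        view = "front view"
    elif 22.5 <= rotation < 67.5:
        view = "front-right quarter view"
    elif 67.5 <= rotation < 112.5:
        view = "right side view"
    elif 112.5 <= rotation < 157.5:
        view = "back-right quarter view"
    else:
        view = "back view"

    # Tilt bins
    if tilt < -0.33:
        shot = "low-angle shot"
    elif tilt < 0.33:
        shot = "eye-level shot"
    elif tilt < 0.66:
        shot = "elevated shot"
    else:
        shot = "high-angle shot"

    # Zoom bins
    zoom_level = "wide shot" if zoom < 3.33 else "medium shot" if zoom < 6.66 else "close-up"
    return f"<sks> {view} {shot} {zoom_level}"


print(bin_camera(0, 0, 5))      # <sks> front view eye-level shot medium shot
print(bin_camera(270, 0.8, 9))  # <sks> left side view high-angle shot close-up
```

Note that 270 degrees normalizes to -90 and therefore lands in the left-side sector, which is why the widget can pass unwrapped rotation values straight through.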