Upload app code and configuration files
Browse files- .gitattributes +1 -0
- README.md +14 -0
- app.py +1380 -0
- camera_control_ui.py +589 -0
- camera_control_ui.pyi +572 -0
- examples/1.jpg +0 -0
- examples/10.jpeg +0 -0
- examples/11.jpg +0 -0
- examples/12.jpg +0 -0
- examples/13.jpg +3 -0
- examples/14.jpg +0 -0
- examples/2.jpeg +0 -0
- examples/4.jpg +0 -0
- examples/5.jpg +0 -0
- examples/6.jpg +0 -0
- examples/7.jpg +0 -0
- examples/8.jpg +0 -0
- examples/9.jpg +0 -0
- examples/ELS.jpg +0 -0
- pre-requirements.txt +1 -0
- qwenimage/__init__.py +0 -0
- qwenimage/pipeline_qwenimage_edit_plus.py +900 -0
- qwenimage/qwen_fa3_processor.py +142 -0
- qwenimage/transformer_qwenimage.py +642 -0
- requirements.txt +14 -0
- setup_manager.py +462 -0
- start_app.sh +7 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/13.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: QIE-2511 Rapid-AIO LoRAs Fast (Experimental)
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.2.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
short_description: Demo of the Collection of Qwen Image Edit LoRAs
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,1380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
try:
|
| 3 |
+
import spaces
|
| 4 |
+
except ImportError:
|
| 5 |
+
class spaces:
|
| 6 |
+
@staticmethod
|
| 7 |
+
def GPU(f): return f
|
| 8 |
+
sys.modules["spaces"] = sys.modules.get("spaces", spaces)
|
| 9 |
+
import os
|
| 10 |
+
from camera_control_ui import CameraControl3D, build_camera_prompt, update_prompt_with_camera
|
| 11 |
+
import re
|
| 12 |
+
import gc
|
| 13 |
+
import traceback
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import numpy as np
|
| 16 |
+
import spaces
|
| 17 |
+
import torch
|
| 18 |
+
import random
|
| 19 |
+
from PIL import Image, ImageDraw
|
| 20 |
+
from typing import Iterable, Optional
|
| 21 |
+
|
| 22 |
+
from transformers import (
|
| 23 |
+
AutoImageProcessor,
|
| 24 |
+
AutoModelForDepthEstimation,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
from huggingface_hub import hf_hub_download
|
| 28 |
+
from safetensors.torch import load_file as safetensors_load_file
|
| 29 |
+
|
| 30 |
+
from gradio.themes import Soft
|
| 31 |
+
from gradio.themes.utils import colors, fonts, sizes
|
| 32 |
+
|
| 33 |
+
# ============================================================
|
| 34 |
+
# Theme
|
| 35 |
+
# ============================================================
|
| 36 |
+
|
| 37 |
+
colors.orange_red = colors.Color(
|
| 38 |
+
name="orange_red",
|
| 39 |
+
c50="#FFF0E5",
|
| 40 |
+
c100="#FFE0CC",
|
| 41 |
+
c200="#FFC299",
|
| 42 |
+
c300="#FFA366",
|
| 43 |
+
c400="#FF8533",
|
| 44 |
+
c500="#FF4500",
|
| 45 |
+
c600="#E63E00",
|
| 46 |
+
c700="#CC3700",
|
| 47 |
+
c800="#B33000",
|
| 48 |
+
c900="#992900",
|
| 49 |
+
c950="#802200",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class OrangeRedTheme(Soft):
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
*,
|
| 57 |
+
primary_hue: colors.Color | str = colors.gray,
|
| 58 |
+
secondary_hue: colors.Color | str = colors.orange_red,
|
| 59 |
+
neutral_hue: colors.Color | str = colors.slate,
|
| 60 |
+
text_size: sizes.Size | str = sizes.text_lg,
|
| 61 |
+
font: fonts.Font | str | Iterable[fonts.Font | str] = (
|
| 62 |
+
fonts.GoogleFont("Outfit"),
|
| 63 |
+
"Arial",
|
| 64 |
+
"sans-serif",
|
| 65 |
+
),
|
| 66 |
+
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
|
| 67 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
| 68 |
+
"ui-monospace",
|
| 69 |
+
"monospace",
|
| 70 |
+
),
|
| 71 |
+
):
|
| 72 |
+
super().__init__(
|
| 73 |
+
primary_hue=primary_hue,
|
| 74 |
+
secondary_hue=secondary_hue,
|
| 75 |
+
neutral_hue=neutral_hue,
|
| 76 |
+
text_size=text_size,
|
| 77 |
+
font=font,
|
| 78 |
+
font_mono=font_mono,
|
| 79 |
+
)
|
| 80 |
+
super().set(
|
| 81 |
+
background_fill_primary="*primary_50",
|
| 82 |
+
background_fill_primary_dark="*primary_900",
|
| 83 |
+
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
|
| 84 |
+
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
|
| 85 |
+
button_primary_text_color="white",
|
| 86 |
+
button_primary_text_color_hover="white",
|
| 87 |
+
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
|
| 88 |
+
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
|
| 89 |
+
button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
|
| 90 |
+
button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
|
| 91 |
+
button_secondary_text_color="black",
|
| 92 |
+
button_secondary_text_color_hover="white",
|
| 93 |
+
button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
|
| 94 |
+
button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
|
| 95 |
+
button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
|
| 96 |
+
button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
|
| 97 |
+
slider_color="*secondary_500",
|
| 98 |
+
slider_color_dark="*secondary_600",
|
| 99 |
+
block_title_text_weight="600",
|
| 100 |
+
block_border_width="3px",
|
| 101 |
+
block_shadow="*shadow_drop_lg",
|
| 102 |
+
button_primary_shadow="*shadow_drop_lg",
|
| 103 |
+
button_large_padding="11px",
|
| 104 |
+
color_accent_soft="*primary_100",
|
| 105 |
+
block_label_background_fill="*primary_200",
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
orange_red_theme = OrangeRedTheme()
|
| 110 |
+
|
| 111 |
+
# ============================================================
|
| 112 |
+
# Device
|
| 113 |
+
# ============================================================
|
| 114 |
+
|
| 115 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 116 |
+
|
| 117 |
+
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
|
| 118 |
+
print("torch.__version__ =", torch.__version__)
|
| 119 |
+
print("torch.version.cuda =", torch.version.cuda)
|
| 120 |
+
print("cuda available:", torch.cuda.is_available())
|
| 121 |
+
print("cuda device count:", torch.cuda.device_count())
|
| 122 |
+
if torch.cuda.is_available():
|
| 123 |
+
print("current device:", torch.cuda.current_device())
|
| 124 |
+
print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
|
| 125 |
+
print("Using device:", device)
|
| 126 |
+
|
| 127 |
+
# ============================================================
|
| 128 |
+
# AIO version (Space variable)
|
| 129 |
+
# ============================================================
|
| 130 |
+
|
| 131 |
+
AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
|
| 132 |
+
DEFAULT_AIO_VERSION = "v19"
|
| 133 |
+
|
| 134 |
+
_VER_RE = re.compile(r"^v\d+$")
|
| 135 |
+
_DIGITS_RE = re.compile(r"^\d+$")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _normalize_version(raw: str) -> Optional[str]:
|
| 139 |
+
if raw is None:
|
| 140 |
+
return None
|
| 141 |
+
s = str(raw).strip()
|
| 142 |
+
if not s:
|
| 143 |
+
return None
|
| 144 |
+
if _VER_RE.fullmatch(s):
|
| 145 |
+
return s
|
| 146 |
+
# forgiving: allow "21" -> "v21"
|
| 147 |
+
if _DIGITS_RE.fullmatch(s):
|
| 148 |
+
return f"v{s}"
|
| 149 |
+
return None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
_AIO_ENV_RAW = os.environ.get("AIO_VERSION", "")
|
| 153 |
+
_AIO_ENV_NORM = _normalize_version(_AIO_ENV_RAW)
|
| 154 |
+
|
| 155 |
+
AIO_VERSION = _AIO_ENV_NORM or DEFAULT_AIO_VERSION
|
| 156 |
+
AIO_VERSION_SOURCE = "env" if _AIO_ENV_NORM else "default(v19)"
|
| 157 |
+
|
| 158 |
+
print(f"AIO_VERSION (env raw) = {_AIO_ENV_RAW!r}")
|
| 159 |
+
print(f"AIO_VERSION (normalized) = {_AIO_ENV_NORM!r}")
|
| 160 |
+
print(f"Using AIO_VERSION = {AIO_VERSION} ({AIO_VERSION_SOURCE})")
|
| 161 |
+
|
| 162 |
+
# ============================================================
|
| 163 |
+
# Pipeline
|
| 164 |
+
# ============================================================
|
| 165 |
+
|
| 166 |
+
from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
|
| 167 |
+
from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
| 168 |
+
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
|
| 169 |
+
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
|
| 170 |
+
|
| 171 |
+
dtype = torch.bfloat16
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _load_pipe_with_version(version: str) -> QwenImageEditPlusPipeline:
|
| 175 |
+
sub = f"{version}/transformer"
|
| 176 |
+
print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {sub}")
|
| 177 |
+
p = QwenImageEditPlusPipeline.from_pretrained(
|
| 178 |
+
"Qwen/Qwen-Image-Edit-2511",
|
| 179 |
+
transformer=QwenImageTransformer2DModel.from_pretrained(
|
| 180 |
+
AIO_REPO_ID,
|
| 181 |
+
subfolder=sub,
|
| 182 |
+
torch_dtype=dtype,
|
| 183 |
+
device_map="auto",
|
| 184 |
+
low_cpu_mem_usage=True,
|
| 185 |
+
),
|
| 186 |
+
torch_dtype=dtype,
|
| 187 |
+
)
|
| 188 |
+
p.enable_model_cpu_offload()
|
| 189 |
+
return p
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Forgiving load: try env/default version, fallback to v19 if it fails
|
| 193 |
+
try:
|
| 194 |
+
pipe = _load_pipe_with_version(AIO_VERSION)
|
| 195 |
+
except Exception as e:
|
| 196 |
+
print("❌ Failed to load requested AIO_VERSION. Falling back to v19.")
|
| 197 |
+
print("---- exception ----")
|
| 198 |
+
print(traceback.format_exc())
|
| 199 |
+
print("-------------------")
|
| 200 |
+
AIO_VERSION = DEFAULT_AIO_VERSION
|
| 201 |
+
AIO_VERSION_SOURCE = "fallback_to_v19"
|
| 202 |
+
pipe = _load_pipe_with_version(AIO_VERSION)
|
| 203 |
+
|
| 204 |
+
# Apply FA3 Optimization
|
| 205 |
+
try:
|
| 206 |
+
print("Skipping FA3 optimization for stability.")
|
| 207 |
+
print("Flash Attention 3 Processor set successfully.")
|
| 208 |
+
except Exception as e:
|
| 209 |
+
print(f"Warning: Could not set FA3 processor: {e}")
|
| 210 |
+
|
| 211 |
+
MAX_SEED = np.iinfo(np.int32).max
|
| 212 |
+
|
| 213 |
+
# ============================================================
|
| 214 |
+
# VAE tiling toggle (UI-controlled; OFF by default)
|
| 215 |
+
# ============================================================
|
| 216 |
+
|
| 217 |
+
def _apply_vae_tiling(enabled: bool):
|
| 218 |
+
"""
|
| 219 |
+
Toggle VAE tiling on the global pipeline.
|
| 220 |
+
|
| 221 |
+
This does NOT require a Space restart; it applies to the next pipe(...) call.
|
| 222 |
+
Note: this is global process state, so concurrent users could flip it between runs.
|
| 223 |
+
"""
|
| 224 |
+
try:
|
| 225 |
+
if enabled:
|
| 226 |
+
if hasattr(pipe, "enable_vae_tiling"):
|
| 227 |
+
pipe.enable_vae_tiling()
|
| 228 |
+
print("✅ VAE tiling ENABLED (per UI).")
|
| 229 |
+
elif hasattr(pipe, "vae") and hasattr(pipe.vae, "enable_tiling"):
|
| 230 |
+
pipe.vae.enable_tiling()
|
| 231 |
+
print("✅ VAE tiling ENABLED via pipe.vae.enable_tiling() (per UI).")
|
| 232 |
+
else:
|
| 233 |
+
print("⚠️ No enable_vae_tiling()/vae.enable_tiling() found; cannot enable.")
|
| 234 |
+
else:
|
| 235 |
+
if hasattr(pipe, "disable_vae_tiling"):
|
| 236 |
+
pipe.disable_vae_tiling()
|
| 237 |
+
print("🛑 VAE tiling DISABLED (per UI).")
|
| 238 |
+
elif hasattr(pipe, "vae") and hasattr(pipe.vae, "disable_tiling"):
|
| 239 |
+
pipe.vae.disable_tiling()
|
| 240 |
+
print("🛑 VAE tiling DISABLED via pipe.vae.disable_tiling() (per UI).")
|
| 241 |
+
else:
|
| 242 |
+
# If no disable method exists, we leave current state unchanged.
|
| 243 |
+
print("⚠️ No disable_vae_tiling()/vae.disable_tiling() found; leaving current state unchanged.")
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"⚠️ VAE tiling toggle failed: {e}")
|
| 246 |
+
|
| 247 |
+
# ============================================================
|
| 248 |
+
# Derived conditioning (Transformers): Depth
|
| 249 |
+
# ============================================================
|
| 250 |
+
# Depth uses Depth Anything V2 Small (Transformers-compatible):
|
| 251 |
+
# https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf
|
| 252 |
+
|
| 253 |
+
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"
|
| 254 |
+
|
| 255 |
+
# Lazy cache keyed by device string ("cpu" / "cuda")
|
| 256 |
+
_DEPTH_CACHE = {}
|
| 257 |
+
|
| 258 |
+
def _derived_device(use_gpu: bool) -> torch.device:
|
| 259 |
+
return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
|
| 260 |
+
|
| 261 |
+
def _load_depth_models(dev: torch.device):
|
| 262 |
+
key = str(dev)
|
| 263 |
+
if key in _DEPTH_CACHE:
|
| 264 |
+
return _DEPTH_CACHE[key]
|
| 265 |
+
proc = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
|
| 266 |
+
model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(dev)
|
| 267 |
+
model.eval()
|
| 268 |
+
_DEPTH_CACHE[key] = (proc, model)
|
| 269 |
+
return _DEPTH_CACHE[key]
|
| 270 |
+
|
| 271 |
+
@torch.inference_mode()
|
| 272 |
+
def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
|
| 273 |
+
dev = _derived_device(use_gpu)
|
| 274 |
+
proc, model = _load_depth_models(dev)
|
| 275 |
+
|
| 276 |
+
w, h = img.size
|
| 277 |
+
inputs = proc(images=img.convert("RGB"), return_tensors="pt").to(dev)
|
| 278 |
+
outputs = model(**inputs)
|
| 279 |
+
predicted = outputs.predicted_depth # [B, H, W]
|
| 280 |
+
|
| 281 |
+
depth = torch.nn.functional.interpolate(
|
| 282 |
+
predicted.unsqueeze(1),
|
| 283 |
+
size=(h, w),
|
| 284 |
+
mode="bicubic",
|
| 285 |
+
align_corners=False,
|
| 286 |
+
).squeeze(1)[0]
|
| 287 |
+
|
| 288 |
+
depth = depth - depth.min()
|
| 289 |
+
depth = depth / (depth.max() + 1e-8)
|
| 290 |
+
depth = (depth * 255.0).clamp(0, 255).to(torch.uint8).cpu().numpy()
|
| 291 |
+
return Image.fromarray(depth).convert("RGB")
|
| 292 |
+
|
| 293 |
+
# ============================================================
|
| 294 |
+
# LoRA adapters + presets
|
| 295 |
+
# ============================================================
|
| 296 |
+
|
| 297 |
+
NONE_LORA = "None"
|
| 298 |
+
|
| 299 |
+
ADAPTER_SPECS = {
|
| 300 |
+
"3D-Camera": {
|
| 301 |
+
"type": "single",
|
| 302 |
+
"repo": "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
|
| 303 |
+
"weights": "qwen-image-edit-2511-multiple-angles-lora.safetensors",
|
| 304 |
+
"adapter_name": "angles",
|
| 305 |
+
"strength": 1.0,
|
| 306 |
+
},
|
| 307 |
+
|
| 308 |
+
"Qwen-lora-nsfw": {
|
| 309 |
+
"type": "single",
|
| 310 |
+
"repo": "wiikoo/Qwen-lora-nsfw",
|
| 311 |
+
"weights": "loras/qwen_image_edit_remove-clothing_v1.0.safetensors",
|
| 312 |
+
"adapter_name": "qwen-lora-nsfw",
|
| 313 |
+
"strength": 1.0,
|
| 314 |
+
},
|
| 315 |
+
|
| 316 |
+
"Consistance": {
|
| 317 |
+
"type": "single",
|
| 318 |
+
"repo": "Pr0f3ssi0n4ln00b/QIE_2511_Consistency_Lora",
|
| 319 |
+
"weights": "qe2511_consis_alpha_patched.safetensors",
|
| 320 |
+
"adapter_name": "Consistency",
|
| 321 |
+
"strength": 0.6,
|
| 322 |
+
},
|
| 323 |
+
"Semirealistic-photo-detailer": {
|
| 324 |
+
"type": "single",
|
| 325 |
+
"repo": "rzgar/Qwen-Image-Edit-semi-realistic-detailer",
|
| 326 |
+
"weights": "Qwen-Image-Edit-Anime-Semi-Realistic-Detailer-v1.safetensors",
|
| 327 |
+
"adapter_name": "semirealistic",
|
| 328 |
+
"strength": 1.0,
|
| 329 |
+
},
|
| 330 |
+
"AnyPose": {
|
| 331 |
+
"type": "package",
|
| 332 |
+
"requires_two_images": True,
|
| 333 |
+
"image2_label": "Upload Pose Reference (Image 2)",
|
| 334 |
+
"parts": [
|
| 335 |
+
{
|
| 336 |
+
"repo": "lilylilith/AnyPose",
|
| 337 |
+
"weights": "2511-AnyPose-base-000006250.safetensors",
|
| 338 |
+
"adapter_name": "anypose-base",
|
| 339 |
+
"strength": 0.7,
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"repo": "lilylilith/AnyPose",
|
| 343 |
+
"weights": "2511-AnyPose-helper-00006000.safetensors",
|
| 344 |
+
"adapter_name": "anypose-helper",
|
| 345 |
+
"strength": 0.7,
|
| 346 |
+
},
|
| 347 |
+
],
|
| 348 |
+
},
|
| 349 |
+
"Any2Real_2601": {
|
| 350 |
+
"type": "single",
|
| 351 |
+
"repo": "lrzjason/Anything2Real_2601",
|
| 352 |
+
"weights": "anything2real_2601_A_final_patched.safetensors",
|
| 353 |
+
"adapter_name": "photoreal",
|
| 354 |
+
"strength": 1.0,
|
| 355 |
+
},
|
| 356 |
+
"Hyperrealistic-Portrait": {
|
| 357 |
+
"type": "single",
|
| 358 |
+
"repo": "prithivMLmods/Qwen-Image-Edit-2511-Hyper-Realistic-Portrait",
|
| 359 |
+
"weights": "HRP_20.safetensors",
|
| 360 |
+
"adapter_name": "HRPortrait",
|
| 361 |
+
"strength": 1.0,
|
| 362 |
+
},
|
| 363 |
+
"Ultrarealistic-Portrait": {
|
| 364 |
+
"type": "single",
|
| 365 |
+
"repo": "prithivMLmods/Qwen-Image-Edit-2511-Ultra-Realistic-Portrait",
|
| 366 |
+
"weights": "URP_20.safetensors",
|
| 367 |
+
"adapter_name": "URPortrait",
|
| 368 |
+
"strength": 1.0,
|
| 369 |
+
},
|
| 370 |
+
"BFS-Best-FaceSwap": {
|
| 371 |
+
"type": "single",
|
| 372 |
+
"requires_two_images": True,
|
| 373 |
+
"image2_label": "Upload Head/Face Donor (Image 2)",
|
| 374 |
+
"repo": "Alissonerdx/BFS-Best-Face-Swap",
|
| 375 |
+
"weights": "bfs_head_v5_2511_original.safetensors",
|
| 376 |
+
"adapter_name": "BFS-Best-Faceswap",
|
| 377 |
+
"strength": 1.0,
|
| 378 |
+
"needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
|
| 379 |
+
},
|
| 380 |
+
"BFS-Best-FaceSwap-merge": {
|
| 381 |
+
"type": "single",
|
| 382 |
+
"requires_two_images": True,
|
| 383 |
+
"image2_label": "Upload Head/Face Donor (Image 2)",
|
| 384 |
+
"repo": "Alissonerdx/BFS-Best-Face-Swap",
|
| 385 |
+
"weights": "bfs_head_v5_2511_merged_version_rank_32_fp32.safetensors",
|
| 386 |
+
"adapter_name": "BFS-Best-Faceswap-merge",
|
| 387 |
+
"strength": 1.1,
|
| 388 |
+
"needs_alpha_fix": True, # <-- fixes KeyError 'img_in.alpha'
|
| 389 |
+
},
|
| 390 |
+
"F2P": {
|
| 391 |
+
"type": "single",
|
| 392 |
+
"repo": "DiffSynth-Studio/Qwen-Image-Edit-F2P",
|
| 393 |
+
"weights": "edit_0928_lora_step40000.safetensors",
|
| 394 |
+
"adapter_name": "F2P",
|
| 395 |
+
"strength": 1.0,
|
| 396 |
+
},
|
| 397 |
+
"Multiple-Angles": {
|
| 398 |
+
"type": "single",
|
| 399 |
+
"repo": "dx8152/Qwen-Edit-2509-Multiple-angles",
|
| 400 |
+
"weights": "镜头转换.safetensors",
|
| 401 |
+
"adapter_name": "multiple-angles",
|
| 402 |
+
"strength": 1.0,
|
| 403 |
+
},
|
| 404 |
+
"Light-Restoration": {
|
| 405 |
+
"type": "single",
|
| 406 |
+
"repo": "dx8152/Qwen-Image-Edit-2509-Light_restoration",
|
| 407 |
+
"weights": "移除光影.safetensors",
|
| 408 |
+
"adapter_name": "light-restoration",
|
| 409 |
+
"strength": 1.0,
|
| 410 |
+
},
|
| 411 |
+
"Relight": {
|
| 412 |
+
"type": "single",
|
| 413 |
+
"repo": "dx8152/Qwen-Image-Edit-2509-Relight",
|
| 414 |
+
"weights": "Qwen-Edit-Relight.safetensors",
|
| 415 |
+
"adapter_name": "relight",
|
| 416 |
+
"strength": 1.0,
|
| 417 |
+
},
|
| 418 |
+
"Multi-Angle-Lighting": {
|
| 419 |
+
"type": "single",
|
| 420 |
+
"repo": "dx8152/Qwen-Edit-2509-Multi-Angle-Lighting",
|
| 421 |
+
"weights": "多角度灯光-251116.safetensors",
|
| 422 |
+
"adapter_name": "multi-angle-lighting",
|
| 423 |
+
"strength": 1.0,
|
| 424 |
+
},
|
| 425 |
+
"Edit-Skin": {
|
| 426 |
+
"type": "single",
|
| 427 |
+
"repo": "tlennon-ie/qwen-edit-skin",
|
| 428 |
+
"weights": "qwen-edit-skin_1.1_000002750.safetensors",
|
| 429 |
+
"adapter_name": "edit-skin",
|
| 430 |
+
"strength": 1.0,
|
| 431 |
+
},
|
| 432 |
+
"Next-Scene": {
|
| 433 |
+
"type": "single",
|
| 434 |
+
"repo": "lovis93/next-scene-qwen-image-lora-2509",
|
| 435 |
+
"weights": "next-scene_lora-v2-3000.safetensors",
|
| 436 |
+
"adapter_name": "next-scene",
|
| 437 |
+
"strength": 1.0,
|
| 438 |
+
},
|
| 439 |
+
"Flat-Log": {
|
| 440 |
+
"type": "single",
|
| 441 |
+
"repo": "tlennon-ie/QwenEdit2509-FlatLogColor",
|
| 442 |
+
"weights": "QwenEdit2509-FlatLogColor.safetensors",
|
| 443 |
+
"adapter_name": "flat-log",
|
| 444 |
+
"strength": 1.0,
|
| 445 |
+
},
|
| 446 |
+
"Upscale-Image": {
|
| 447 |
+
"type": "single",
|
| 448 |
+
"repo": "vafipas663/Qwen-Edit-2509-Upscale-LoRA",
|
| 449 |
+
"weights": "qwen-edit-enhance_64-v3_000001000.safetensors",
|
| 450 |
+
"adapter_name": "upscale-image",
|
| 451 |
+
"strength": 1.0,
|
| 452 |
+
},
|
| 453 |
+
"Upscale2K": {
|
| 454 |
+
"type": "single",
|
| 455 |
+
"repo": "valiantcat/Qwen-Image-Edit-2509-Upscale2K",
|
| 456 |
+
"weights": "qwen_image_edit_2509_upscale.safetensors",
|
| 457 |
+
"adapter_name": "upscale-2k",
|
| 458 |
+
"strength": 1.0,
|
| 459 |
+
"target_long_edge": 2048,
|
| 460 |
+
},
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
LORA_PRESET_PROMPTS = {
|
| 464 |
+
"Any2Real_2601": "change the picture 1 to realistic photograph",
|
| 465 |
+
"Semirealistic-photo-detailer": "transform the image to semi-realistic image",
|
| 466 |
+
"AnyPose": "Make the person in image 1 do the exact same pose of the person in image 2. Changing the style and background of the image of the person in image 1 is undesirable, so don't do it. The new pose should be pixel accurate to the pose we are trying to copy. The position of the arms and head and legs should be the same as the pose we are trying to copy. Change the field of view and angle to match exactly image 2. Head tilt and eye gaze pose should match the person in image 2.",
|
| 467 |
+
"Hyperrealistic-Portrait": "Transform the image into an ultra-realistic photorealistic portrait with strict identity preservation, facing straight to the camera. Enhance pore-level skin textures, realistic moisture effects, and natural wet hair clumping against the skin. Apply cool-toned soft-box lighting with subtle highlights and shadows, maintain realistic green-hazel eye catchlights without synthetic gloss, and preserve soft natural lip texture. Use shallow depth of field with a clean bokeh background, an 85mm macro photographic look, and raw photo grading without retouching to maintain realism and original details.",
|
| 468 |
+
"Ultrarealistic-Portrait": "Transform the image into an ultra-realistic glamour portrait while strictly preserving the subject’s identity. Apply a close-up composition with a slight head tilt and a hand near the face, enhance cinematic directional lighting with dramatic fashion-style highlights, and refine makeup details including glowing skin, glossy lips, luminous highlighter, and defined eyes. Increase skin realism with detailed epidermal textures such as micropores, microhairs, subtle oil sheen, natural highlights, soft wrinkles, and subsurface scattering. Maintain a luxury fashion-magazine look in a 9:16 aspect ratio, preserving realism, facial structure, and original details without over-smoothing or retouching.",
|
| 469 |
+
"Upscale2K": "Upscale this picture to 4K resolution.",
|
| 470 |
+
"BFS-Best-FaceSwap": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 471 |
+
"BFS-Best-FaceSwap-merge": "head_swap: start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
# Track what is currently loaded in memory (adapter_name values)
|
| 475 |
+
LOADED_ADAPTERS = set()
|
| 476 |
+
|
| 477 |
+
# ============================================================
|
| 478 |
+
# Helpers: resolution
|
| 479 |
+
# ============================================================
|
| 480 |
+
|
| 481 |
+
# We prefer *area-based* sizing (≈ megapixels) over long-edge sizing.
|
| 482 |
+
# This aligns better with Qwen-Image-Edit's internal assumptions and reduces FOV drift.
|
| 483 |
+
|
| 484 |
+
def _round_to_multiple(x: int, m: int) -> int:
|
| 485 |
+
return max(m, (int(x) // m) * m)
|
| 486 |
+
|
| 487 |
+
def compute_canvas_dimensions_from_area(
|
| 488 |
+
image: Image.Image,
|
| 489 |
+
target_area: int,
|
| 490 |
+
multiple_of: int,
|
| 491 |
+
) -> tuple[int, int]:
|
| 492 |
+
"""Compute (width, height) that matches image aspect ratio and approximates target_area.
|
| 493 |
+
|
| 494 |
+
The result is floored to be divisible by multiple_of (typically vae_scale_factor*2).
|
| 495 |
+
"""
|
| 496 |
+
w, h = image.size
|
| 497 |
+
aspect = w / h if h else 1.0
|
| 498 |
+
|
| 499 |
+
# Use the pipeline's own area->(w,h) helper for consistency.
|
| 500 |
+
from qwenimage.pipeline_qwenimage_edit_plus import calculate_dimensions
|
| 501 |
+
|
| 502 |
+
width, height = calculate_dimensions(int(target_area), float(aspect))
|
| 503 |
+
width = _round_to_multiple(int(width), int(multiple_of))
|
| 504 |
+
height = _round_to_multiple(int(height), int(multiple_of))
|
| 505 |
+
return width, height
|
| 506 |
+
|
| 507 |
+
def get_target_area_for_lora(
|
| 508 |
+
image: Image.Image,
|
| 509 |
+
lora_adapter: str,
|
| 510 |
+
user_target_megapixels: float,
|
| 511 |
+
) -> int:
|
| 512 |
+
"""Return target pixel area for the canvas.
|
| 513 |
+
|
| 514 |
+
Priority:
|
| 515 |
+
1) Adapter spec: target_area (pixels) or target_megapixels
|
| 516 |
+
2) Adapter spec: target_long_edge (legacy) -> converted to area using image aspect
|
| 517 |
+
3) User slider target megapixels
|
| 518 |
+
"""
|
| 519 |
+
spec = ADAPTER_SPECS.get(lora_adapter, {})
|
| 520 |
+
|
| 521 |
+
if "target_area" in spec:
|
| 522 |
+
try:
|
| 523 |
+
return int(spec["target_area"])
|
| 524 |
+
except Exception:
|
| 525 |
+
pass
|
| 526 |
+
|
| 527 |
+
if "target_megapixels" in spec:
|
| 528 |
+
try:
|
| 529 |
+
mp = float(spec["target_megapixels"])
|
| 530 |
+
return int(mp * 1024 * 1024)
|
| 531 |
+
except Exception:
|
| 532 |
+
pass
|
| 533 |
+
|
| 534 |
+
# Legacy support (e.g. Upscale2K)
|
| 535 |
+
if "target_long_edge" in spec:
|
| 536 |
+
try:
|
| 537 |
+
long_edge = int(spec["target_long_edge"])
|
| 538 |
+
w, h = image.size
|
| 539 |
+
if w >= h:
|
| 540 |
+
new_w = long_edge
|
| 541 |
+
new_h = int(round(long_edge * (h / w)))
|
| 542 |
+
else:
|
| 543 |
+
new_h = long_edge
|
| 544 |
+
new_w = int(round(long_edge * (w / h)))
|
| 545 |
+
return int(new_w * new_h)
|
| 546 |
+
except Exception:
|
| 547 |
+
pass
|
| 548 |
+
|
| 549 |
+
# User default
|
| 550 |
+
try:
|
| 551 |
+
mp = float(user_target_megapixels)
|
| 552 |
+
except Exception:
|
| 553 |
+
mp = 1.0
|
| 554 |
+
|
| 555 |
+
# Treat 0 MP as "match input area"
|
| 556 |
+
if mp <= 0:
|
| 557 |
+
w, h = image.size
|
| 558 |
+
return int(w * h)
|
| 559 |
+
|
| 560 |
+
return int(mp * 1024 * 1024)
|
| 561 |
+
|
| 562 |
+
# ============================================================
|
| 563 |
+
# Helpers: multi-input routing + gallery normalization
|
| 564 |
+
# ============================================================
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def lora_requires_two_images(lora_adapter: str) -> bool:
|
| 568 |
+
return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def image2_label_for_lora(lora_adapter: str) -> str:
|
| 572 |
+
return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _to_pil_rgb(x) -> Optional[Image.Image]:
|
| 576 |
+
"""
|
| 577 |
+
Accepts PIL / numpy / (image, caption) tuples from gr.Gallery and returns PIL RGB.
|
| 578 |
+
Gradio Gallery commonly yields tuples like (image, caption).
|
| 579 |
+
"""
|
| 580 |
+
if x is None:
|
| 581 |
+
return None
|
| 582 |
+
|
| 583 |
+
# Gallery often returns (image, caption)
|
| 584 |
+
if isinstance(x, tuple) and len(x) >= 1:
|
| 585 |
+
x = x[0]
|
| 586 |
+
if x is None:
|
| 587 |
+
return None
|
| 588 |
+
|
| 589 |
+
if isinstance(x, Image.Image):
|
| 590 |
+
return x.convert("RGB")
|
| 591 |
+
|
| 592 |
+
if isinstance(x, np.ndarray):
|
| 593 |
+
return Image.fromarray(x).convert("RGB")
|
| 594 |
+
|
| 595 |
+
# Best-effort fallback
|
| 596 |
+
try:
|
| 597 |
+
return Image.fromarray(np.array(x)).convert("RGB")
|
| 598 |
+
except Exception:
|
| 599 |
+
return None
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def build_labeled_images(
|
| 603 |
+
img1: Image.Image,
|
| 604 |
+
img2: Optional[Image.Image],
|
| 605 |
+
extra_imgs: Optional[list[Image.Image]],
|
| 606 |
+
) -> dict[str, Image.Image]:
|
| 607 |
+
"""
|
| 608 |
+
Creates labels image_1, image_2, image_3... based on what is actually uploaded:
|
| 609 |
+
- img1 is always image_1
|
| 610 |
+
- img2 becomes image_2 only if present
|
| 611 |
+
- extras start immediately after the last present base box
|
| 612 |
+
The pipeline receives images in this exact order.
|
| 613 |
+
"""
|
| 614 |
+
labeled: dict[str, Image.Image] = {}
|
| 615 |
+
idx = 1
|
| 616 |
+
|
| 617 |
+
labeled[f"image_{idx}"] = img1
|
| 618 |
+
idx += 1
|
| 619 |
+
|
| 620 |
+
if img2 is not None:
|
| 621 |
+
labeled[f"image_{idx}"] = img2
|
| 622 |
+
idx += 1
|
| 623 |
+
|
| 624 |
+
if extra_imgs:
|
| 625 |
+
for im in extra_imgs:
|
| 626 |
+
if im is None:
|
| 627 |
+
continue
|
| 628 |
+
labeled[f"image_{idx}"] = im
|
| 629 |
+
idx += 1
|
| 630 |
+
|
| 631 |
+
return labeled
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
# ============================================================
|
| 635 |
+
# Helpers: BFS alpha key fix
|
| 636 |
+
# ============================================================
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def _inject_missing_alpha_keys(state_dict: dict) -> dict:
|
| 640 |
+
"""
|
| 641 |
+
Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
|
| 642 |
+
BFS safetensors omits them. We inject alpha = rank (neutral scaling).
|
| 643 |
+
|
| 644 |
+
IMPORTANT: diffusers may strip 'diffusion_model.' before lookup, so we
|
| 645 |
+
inject BOTH:
|
| 646 |
+
- diffusion_model.xxx.alpha
|
| 647 |
+
- xxx.alpha
|
| 648 |
+
"""
|
| 649 |
+
bases = {}
|
| 650 |
+
|
| 651 |
+
for k, v in state_dict.items():
|
| 652 |
+
if not isinstance(v, torch.Tensor):
|
| 653 |
+
continue
|
| 654 |
+
if k.endswith(".lora_down.weight") and v.ndim >= 1:
|
| 655 |
+
base = k[: -len(".lora_down.weight")]
|
| 656 |
+
rank = int(v.shape[0])
|
| 657 |
+
bases[base] = rank
|
| 658 |
+
|
| 659 |
+
for base, rank in bases.items():
|
| 660 |
+
alpha_tensor = torch.tensor(float(rank), dtype=torch.float32)
|
| 661 |
+
|
| 662 |
+
full_alpha = f"{base}.alpha"
|
| 663 |
+
if full_alpha not in state_dict:
|
| 664 |
+
state_dict[full_alpha] = alpha_tensor
|
| 665 |
+
|
| 666 |
+
if base.startswith("diffusion_model."):
|
| 667 |
+
stripped_base = base[len("diffusion_model.") :]
|
| 668 |
+
stripped_alpha = f"{stripped_base}.alpha"
|
| 669 |
+
if stripped_alpha not in state_dict:
|
| 670 |
+
state_dict[stripped_alpha] = alpha_tensor
|
| 671 |
+
|
| 672 |
+
return state_dict
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def _filter_to_diffusers_lora_keys(state_dict: dict) -> tuple[dict, dict]:
|
| 676 |
+
"""Return (filtered_state_dict, stats).
|
| 677 |
+
|
| 678 |
+
Some ComfyUI/Qwen safetensors (especially "merged" variants) include non-LoRA
|
| 679 |
+
delta/patch keys like `*.diff` and `*.diff_b` alongside real LoRA tensors.
|
| 680 |
+
Diffusers' internal Qwen LoRA converter is strict: any leftover keys cause an
|
| 681 |
+
error (`state_dict should be empty...`).
|
| 682 |
+
|
| 683 |
+
This helper keeps only the keys Diffusers can consume as a LoRA:
|
| 684 |
+
- `*.lora_up.weight`
|
| 685 |
+
- `*.lora_down.weight`
|
| 686 |
+
- (rare) `*.lora_mid.weight`
|
| 687 |
+
- alpha keys: `*.alpha` (or `*.lora_alpha` which we normalize to `*.alpha`)
|
| 688 |
+
|
| 689 |
+
It also drops known patch keys (`*.diff`, `*.diff_b`) and everything else.
|
| 690 |
+
"""
|
| 691 |
+
|
| 692 |
+
keep_suffixes = (
|
| 693 |
+
".lora_up.weight",
|
| 694 |
+
".lora_down.weight",
|
| 695 |
+
".lora_mid.weight",
|
| 696 |
+
".alpha",
|
| 697 |
+
".lora_alpha",
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
dropped_patch = 0
|
| 701 |
+
dropped_other = 0
|
| 702 |
+
kept = 0
|
| 703 |
+
normalized_alpha = 0
|
| 704 |
+
|
| 705 |
+
out: dict[str, torch.Tensor] = {}
|
| 706 |
+
for k, v in state_dict.items():
|
| 707 |
+
if not isinstance(v, torch.Tensor):
|
| 708 |
+
# Ignore non-tensor entries if any.
|
| 709 |
+
dropped_other += 1
|
| 710 |
+
continue
|
| 711 |
+
|
| 712 |
+
# Drop ComfyUI "delta" keys that Diffusers' LoRA loader will never consume.
|
| 713 |
+
if k.endswith(".diff") or k.endswith(".diff_b"):
|
| 714 |
+
dropped_patch += 1
|
| 715 |
+
continue
|
| 716 |
+
|
| 717 |
+
if not k.endswith(keep_suffixes):
|
| 718 |
+
dropped_other += 1
|
| 719 |
+
continue
|
| 720 |
+
|
| 721 |
+
if k.endswith(".lora_alpha"):
|
| 722 |
+
# Normalize common alt name to what Diffusers expects.
|
| 723 |
+
base = k[: -len(".lora_alpha")]
|
| 724 |
+
k2 = f"{base}.alpha"
|
| 725 |
+
out[k2] = v.float() if v.dtype != torch.float32 else v
|
| 726 |
+
normalized_alpha += 1
|
| 727 |
+
kept += 1
|
| 728 |
+
continue
|
| 729 |
+
|
| 730 |
+
out[k] = v
|
| 731 |
+
kept += 1
|
| 732 |
+
|
| 733 |
+
stats = {
|
| 734 |
+
"kept": kept,
|
| 735 |
+
"dropped_patch": dropped_patch,
|
| 736 |
+
"dropped_other": dropped_other,
|
| 737 |
+
"normalized_alpha": normalized_alpha,
|
| 738 |
+
}
|
| 739 |
+
return out, stats
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def _duplicate_stripped_prefix_keys(state_dict: dict, prefix: str = "diffusion_model.") -> dict:
|
| 743 |
+
"""Ensure both prefixed and unprefixed variants exist for LoRA-related keys.
|
| 744 |
+
|
| 745 |
+
Diffusers' Qwen LoRA conversion may strip `diffusion_model.` when looking up
|
| 746 |
+
modules. Some exports only include prefixed keys. To be maximally compatible,
|
| 747 |
+
we duplicate LoRA keys (and alpha) in stripped form when missing.
|
| 748 |
+
"""
|
| 749 |
+
|
| 750 |
+
out = dict(state_dict)
|
| 751 |
+
for k, v in list(state_dict.items()):
|
| 752 |
+
if not k.startswith(prefix):
|
| 753 |
+
continue
|
| 754 |
+
stripped = k[len(prefix) :]
|
| 755 |
+
if stripped not in out:
|
| 756 |
+
out[stripped] = v
|
| 757 |
+
return out
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
|
| 761 |
+
"""
|
| 762 |
+
Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
|
| 763 |
+
BFS fallback: download safetensors, inject missing alpha keys, then load from dict.
|
| 764 |
+
"""
|
| 765 |
+
try:
|
| 766 |
+
pipe.load_lora_weights(repo, weight_name=weight_name, adapter_name=adapter_name)
|
| 767 |
+
return
|
| 768 |
+
except (KeyError, ValueError) as e:
|
| 769 |
+
# KeyError: missing required alpha keys (common in BFS)
|
| 770 |
+
# ValueError: Diffusers Qwen converter found leftover keys (e.g. .diff/.diff_b)
|
| 771 |
+
if not needs_alpha_fix:
|
| 772 |
+
raise
|
| 773 |
+
|
| 774 |
+
print(
|
| 775 |
+
"⚠️ LoRA load failed (will try safe dict fallback). "
|
| 776 |
+
f"Adapter={adapter_name!r} file={weight_name!r} error={type(e).__name__}: {e}"
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
local_path = hf_hub_download(repo_id=repo, filename=weight_name)
|
| 780 |
+
sd = safetensors_load_file(local_path)
|
| 781 |
+
|
| 782 |
+
# 1) Inject required `<module>.alpha` keys (neutral scaling alpha=rank).
|
| 783 |
+
sd = _inject_missing_alpha_keys(sd)
|
| 784 |
+
|
| 785 |
+
# 2) Keep only LoRA + alpha keys; drop ComfyUI patch/delta keys.
|
| 786 |
+
sd, stats = _filter_to_diffusers_lora_keys(sd)
|
| 787 |
+
|
| 788 |
+
# 3) Duplicate stripped keys (remove `diffusion_model.`) for compatibility.
|
| 789 |
+
sd = _duplicate_stripped_prefix_keys(sd)
|
| 790 |
+
|
| 791 |
+
print(
|
| 792 |
+
"🧹 LoRA dict cleanup stats: "
|
| 793 |
+
f"kept={stats['kept']} dropped_patch={stats['dropped_patch']} "
|
| 794 |
+
f"dropped_other={stats['dropped_other']} normalized_alpha={stats['normalized_alpha']}"
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
pipe.load_lora_weights(sd, adapter_name=adapter_name)
|
| 798 |
+
return
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
# ============================================================
|
| 802 |
+
# LoRA loader: single/package + strengths
|
| 803 |
+
# ============================================================
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def _ensure_loaded_and_get_active_adapters(selected_lora: str):
|
| 807 |
+
spec = ADAPTER_SPECS.get(selected_lora)
|
| 808 |
+
if not spec:
|
| 809 |
+
raise gr.Error(f"Configuration not found for: {selected_lora}")
|
| 810 |
+
|
| 811 |
+
adapter_names = []
|
| 812 |
+
adapter_weights = []
|
| 813 |
+
|
| 814 |
+
if spec.get("type") == "package":
|
| 815 |
+
parts = spec.get("parts", [])
|
| 816 |
+
if not parts:
|
| 817 |
+
raise gr.Error(f"Package spec has no parts: {selected_lora}")
|
| 818 |
+
|
| 819 |
+
for part in parts:
|
| 820 |
+
repo = part["repo"]
|
| 821 |
+
weights = part["weights"]
|
| 822 |
+
adapter_name = part["adapter_name"]
|
| 823 |
+
strength = float(part.get("strength", 1.0))
|
| 824 |
+
needs_alpha_fix = bool(part.get("needs_alpha_fix", False))
|
| 825 |
+
|
| 826 |
+
if adapter_name not in LOADED_ADAPTERS:
|
| 827 |
+
print(f"--- Downloading and Loading Adapter Part: {selected_lora} / {adapter_name} ---")
|
| 828 |
+
try:
|
| 829 |
+
_load_lora_weights_with_fallback(
|
| 830 |
+
repo=repo,
|
| 831 |
+
weight_name=weights,
|
| 832 |
+
adapter_name=adapter_name,
|
| 833 |
+
needs_alpha_fix=needs_alpha_fix,
|
| 834 |
+
)
|
| 835 |
+
LOADED_ADAPTERS.add(adapter_name)
|
| 836 |
+
except Exception as e:
|
| 837 |
+
raise gr.Error(f"Failed to load adapter part {selected_lora}/{adapter_name}: {e}")
|
| 838 |
+
else:
|
| 839 |
+
print(f"--- Adapter part already loaded: {selected_lora} / {adapter_name} ---")
|
| 840 |
+
|
| 841 |
+
adapter_names.append(adapter_name)
|
| 842 |
+
adapter_weights.append(strength)
|
| 843 |
+
|
| 844 |
+
else:
|
| 845 |
+
repo = spec["repo"]
|
| 846 |
+
weights = spec["weights"]
|
| 847 |
+
adapter_name = spec["adapter_name"]
|
| 848 |
+
strength = float(spec.get("strength", 1.0))
|
| 849 |
+
needs_alpha_fix = bool(spec.get("needs_alpha_fix", False))
|
| 850 |
+
|
| 851 |
+
if adapter_name not in LOADED_ADAPTERS:
|
| 852 |
+
print(f"--- Downloading and Loading Adapter: {selected_lora} ---")
|
| 853 |
+
try:
|
| 854 |
+
_load_lora_weights_with_fallback(
|
| 855 |
+
repo=repo,
|
| 856 |
+
weight_name=weights,
|
| 857 |
+
adapter_name=adapter_name,
|
| 858 |
+
needs_alpha_fix=needs_alpha_fix,
|
| 859 |
+
)
|
| 860 |
+
LOADED_ADAPTERS.add(adapter_name)
|
| 861 |
+
except Exception as e:
|
| 862 |
+
raise gr.Error(f"Failed to load adapter {selected_lora}: {e}")
|
| 863 |
+
else:
|
| 864 |
+
print(f"--- Adapter {selected_lora} is already loaded. ---")
|
| 865 |
+
|
| 866 |
+
adapter_names = [adapter_name]
|
| 867 |
+
adapter_weights = [strength]
|
| 868 |
+
|
| 869 |
+
return adapter_names, adapter_weights
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
# ============================================================
|
| 873 |
+
# UI handlers
|
| 874 |
+
# ============================================================
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
|
| 879 |
+
prompt_val = current_prompt
|
| 880 |
+
if selected_lora != NONE_LORA:
|
| 881 |
+
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 882 |
+
if preset:
|
| 883 |
+
prompt_val = preset
|
| 884 |
+
else:
|
| 885 |
+
prompt_val = "" # CLEAR THE PROMPT IF ACTIVE BUT NO PRESET
|
| 886 |
+
|
| 887 |
+
prompt_update = gr.update(value=prompt_val)
|
| 888 |
+
camera_update = gr.update(visible=(selected_lora == "3D-Camera"))
|
| 889 |
+
|
| 890 |
+
# Image2 visibility/label
|
| 891 |
+
if lora_requires_two_images(selected_lora):
|
| 892 |
+
img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
|
| 893 |
+
else:
|
| 894 |
+
img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')
|
| 895 |
+
|
| 896 |
+
# Extra references routing default
|
| 897 |
+
if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
|
| 898 |
+
extras_update = gr.update(value=True)
|
| 899 |
+
else:
|
| 900 |
+
extras_update = gr.update(value=extras_condition_only)
|
| 901 |
+
|
| 902 |
+
return prompt_update, img2_update, extras_update, camera_update
|
| 903 |
+
# ============================================================
|
| 904 |
+
# UI helpers: output routing + derived conditioning
|
| 905 |
+
|
| 906 |
+
def _append_to_gallery(existing_gallery, new_image):
|
| 907 |
+
if existing_gallery is None:
|
| 908 |
+
return [new_image]
|
| 909 |
+
if not isinstance(existing_gallery, list):
|
| 910 |
+
existing_gallery = [existing_gallery]
|
| 911 |
+
existing_gallery.append(new_image)
|
| 912 |
+
return existing_gallery
|
| 913 |
+
|
| 914 |
+
# ============================================================
|
| 915 |
+
|
| 916 |
+
def set_output_as_image1(last):
|
| 917 |
+
if last is None:
|
| 918 |
+
raise gr.Error("No output available yet.")
|
| 919 |
+
return gr.update(value=last)
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
def set_output_as_image2(last):
|
| 923 |
+
if last is None:
|
| 924 |
+
raise gr.Error("No output available yet.")
|
| 925 |
+
return gr.update(value=last)
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def set_output_as_extra(last, existing_extra):
|
| 929 |
+
if last is None:
|
| 930 |
+
raise gr.Error("No output available yet.")
|
| 931 |
+
return _append_to_gallery(existing_extra, last)
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
@spaces.GPU
|
| 935 |
+
def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu):
|
| 936 |
+
if img1 is None:
|
| 937 |
+
raise gr.Error("Please upload Image 1 first.")
|
| 938 |
+
|
| 939 |
+
if derived_type == "None":
|
| 940 |
+
return gr.update(value=existing_extra), gr.update(visible=False, value=None)
|
| 941 |
+
|
| 942 |
+
base = img1.convert("RGB")
|
| 943 |
+
|
| 944 |
+
if derived_type == "Depth (Depth Anything V2 Small)":
|
| 945 |
+
derived = make_depth_map(base, use_gpu=bool(derived_use_gpu))
|
| 946 |
+
else:
|
| 947 |
+
raise gr.Error(f"Unknown derived type: {derived_type}")
|
| 948 |
+
|
| 949 |
+
new_gallery = _append_to_gallery(existing_extra, derived)
|
| 950 |
+
return gr.update(value=new_gallery), gr.update(visible=True, value=derived)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
# ============================================================
|
| 954 |
+
# Inference
|
| 955 |
+
# ============================================================
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
@spaces.GPU
|
| 959 |
+
def infer(
|
| 960 |
+
input_image_1,
|
| 961 |
+
input_image_2,
|
| 962 |
+
input_images_extra, # gallery multi-image box
|
| 963 |
+
prompt,
|
| 964 |
+
lora_adapter,
|
| 965 |
+
seed,
|
| 966 |
+
randomize_seed,
|
| 967 |
+
guidance_scale,
|
| 968 |
+
steps,
|
| 969 |
+
target_megapixels,
|
| 970 |
+
extras_condition_only,
|
| 971 |
+
pad_to_canvas,
|
| 972 |
+
vae_tiling, # VAE tiling toggle
|
| 973 |
+
resolution_multiple,
|
| 974 |
+
vae_ref_megapixels,
|
| 975 |
+
decoder_vae,
|
| 976 |
+
keep_decoder_2x,
|
| 977 |
+
progress=gr.Progress(track_tqdm=True),
|
| 978 |
+
):
|
| 979 |
+
gc.collect()
|
| 980 |
+
if torch.cuda.is_available():
|
| 981 |
+
torch.cuda.empty_cache()
|
| 982 |
+
|
| 983 |
+
if input_image_1 is None:
|
| 984 |
+
raise gr.Error("Please upload Image 1.")
|
| 985 |
+
|
| 986 |
+
# Handle "None"
|
| 987 |
+
if lora_adapter == NONE_LORA:
|
| 988 |
+
try:
|
| 989 |
+
pipe.set_adapters([], adapter_weights=[])
|
| 990 |
+
except Exception:
|
| 991 |
+
if LOADED_ADAPTERS:
|
| 992 |
+
pipe.set_adapters(list(LOADED_ADAPTERS), adapter_weights=[0.0] * len(LOADED_ADAPTERS))
|
| 993 |
+
else:
|
| 994 |
+
adapter_names, adapter_weights = _ensure_loaded_and_get_active_adapters(lora_adapter)
|
| 995 |
+
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
| 996 |
+
|
| 997 |
+
if randomize_seed:
|
| 998 |
+
seed = random.randint(0, MAX_SEED)
|
| 999 |
+
|
| 1000 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 1001 |
+
negative_prompt = (
|
| 1002 |
+
"worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
|
| 1003 |
+
"extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
img1 = input_image_1.convert("RGB")
|
| 1007 |
+
img2 = input_image_2.convert("RGB") if input_image_2 is not None else None
|
| 1008 |
+
|
| 1009 |
+
# Normalize extra images (Gallery) to PIL RGB (handles tuples from Gallery)
|
| 1010 |
+
extra_imgs: list[Image.Image] = []
|
| 1011 |
+
if input_images_extra:
|
| 1012 |
+
for item in input_images_extra:
|
| 1013 |
+
pil = _to_pil_rgb(item)
|
| 1014 |
+
if pil is not None:
|
| 1015 |
+
extra_imgs.append(pil)
|
| 1016 |
+
|
| 1017 |
+
# Enforce existing 2-image LoRA behavior (image_1 + image_2 required)
|
| 1018 |
+
if lora_requires_two_images(lora_adapter) and img2 is None:
|
| 1019 |
+
raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
|
| 1020 |
+
|
| 1021 |
+
# Label images as image_1, image_2, image_3...
|
| 1022 |
+
labeled = build_labeled_images(img1, img2, extra_imgs)
|
| 1023 |
+
|
| 1024 |
+
# Pass to pipeline in labeled order. Keep single-image call when only one is present.
|
| 1025 |
+
pipe_images = list(labeled.values())
|
| 1026 |
+
if len(pipe_images) == 1:
|
| 1027 |
+
pipe_images = pipe_images[0]
|
| 1028 |
+
|
| 1029 |
+
# Resolution derived from Image 1 (base/body/target)
|
| 1030 |
+
# Use target *area* (≈ megapixels) rather than long-edge sizing to reduce FOV drift.
|
| 1031 |
+
target_area = get_target_area_for_lora(img1, lora_adapter, float(target_megapixels))
|
| 1032 |
+
width, height = compute_canvas_dimensions_from_area(
|
| 1033 |
+
img1,
|
| 1034 |
+
target_area=target_area,
|
| 1035 |
+
multiple_of=int(resolution_multiple),
|
| 1036 |
+
)
|
| 1037 |
+
|
| 1038 |
+
# Decide which images participate in the VAE latent stream.
|
| 1039 |
+
# If enabled, extra references beyond (Img_1, Img_2) become conditioning-only.
|
| 1040 |
+
vae_image_indices = None
|
| 1041 |
+
if extras_condition_only:
|
| 1042 |
+
if isinstance(pipe_images, list) and len(pipe_images) > 2:
|
| 1043 |
+
vae_image_indices = [0, 1] if len(pipe_images) >= 2 else [0]
|
| 1044 |
+
|
| 1045 |
+
try:
|
| 1046 |
+
print(
|
| 1047 |
+
"[DEBUG][infer] submitting request | "
|
| 1048 |
+
f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r}"
|
| 1049 |
+
)
|
| 1050 |
+
print(f"[DEBUG][infer] canvas={width}x{height} (~{(width*height)/1_048_576:.3f} MP) vae_tiling={bool(vae_tiling)}")
|
| 1051 |
+
|
| 1052 |
+
# ✅ Apply UI toggle per-request (OFF by default)
|
| 1053 |
+
# Lattice multiple passed to pipeline too (anti-drift / valid size grid)
|
| 1054 |
+
res_mult = int(resolution_multiple) if resolution_multiple is not None else int(pipe.vae_scale_factor * 2)
|
| 1055 |
+
|
| 1056 |
+
# Optional: override VAE sizing for *extra* references (beyond Image 1 / Image 2)
|
| 1057 |
+
# Interpreted as megapixels; 0 disables override (uses canvas).
|
| 1058 |
+
try:
|
| 1059 |
+
mp_ref = float(vae_ref_megapixels)
|
| 1060 |
+
except Exception:
|
| 1061 |
+
mp_ref = 0.0
|
| 1062 |
+
|
| 1063 |
+
vae_ref_area = int(mp_ref * 1024 * 1024) if mp_ref and mp_ref > 0 else None
|
| 1064 |
+
|
| 1065 |
+
# Extras start index depends on whether Image 2 exists
|
| 1066 |
+
base_ref_count = 2 if img2 is not None else 1
|
| 1067 |
+
|
| 1068 |
+
_apply_vae_tiling(bool(vae_tiling))
|
| 1069 |
+
|
| 1070 |
+
result = pipe(
|
| 1071 |
+
image=pipe_images,
|
| 1072 |
+
prompt=prompt,
|
| 1073 |
+
negative_prompt=negative_prompt,
|
| 1074 |
+
height=height,
|
| 1075 |
+
width=width,
|
| 1076 |
+
num_inference_steps=steps,
|
| 1077 |
+
generator=generator,
|
| 1078 |
+
true_cfg_scale=guidance_scale,
|
| 1079 |
+
vae_image_indices=vae_image_indices,
|
| 1080 |
+
pad_to_canvas=bool(pad_to_canvas),
|
| 1081 |
+
resolution_multiple=res_mult,
|
| 1082 |
+
vae_ref_area=vae_ref_area,
|
| 1083 |
+
vae_ref_start_index=base_ref_count,
|
| 1084 |
+
decoder_vae=str(decoder_vae).lower(),
|
| 1085 |
+
keep_decoder_2x=bool(keep_decoder_2x),
|
| 1086 |
+
).images[0]
|
| 1087 |
+
return result, seed, result
|
| 1088 |
+
finally:
|
| 1089 |
+
gc.collect()
|
| 1090 |
+
if torch.cuda.is_available():
|
| 1091 |
+
torch.cuda.empty_cache()
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
@spaces.GPU
|
| 1095 |
+
def infer_example(input_image, prompt, lora_adapter):
|
| 1096 |
+
if input_image is None:
|
| 1097 |
+
return None, 0, None
|
| 1098 |
+
input_pil = input_image.convert("RGB")
|
| 1099 |
+
guidance_scale = 1.0
|
| 1100 |
+
steps = 4
|
| 1101 |
+
# Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
|
| 1102 |
+
# Keep VAE tiling OFF in examples (matches default).
|
| 1103 |
+
result, seed, last = infer(
|
| 1104 |
+
input_pil,
|
| 1105 |
+
None,
|
| 1106 |
+
None,
|
| 1107 |
+
prompt,
|
| 1108 |
+
lora_adapter,
|
| 1109 |
+
0,
|
| 1110 |
+
True,
|
| 1111 |
+
guidance_scale,
|
| 1112 |
+
steps,
|
| 1113 |
+
1.0,
|
| 1114 |
+
True,
|
| 1115 |
+
True,
|
| 1116 |
+
False, # vae_tiling
|
| 1117 |
+
)
|
| 1118 |
+
return result, seed, last
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
# ============================================================
|
| 1122 |
+
# UI
|
| 1123 |
+
# ============================================================
|
| 1124 |
+
|
| 1125 |
+
css = """
|
| 1126 |
+
#col-container {
|
| 1127 |
+
margin: 0 auto;
|
| 1128 |
+
max-width: 960px;
|
| 1129 |
+
}
|
| 1130 |
+
#main-title h1 {font-size: 2.1em !important;}
|
| 1131 |
+
"""
|
| 1132 |
+
|
| 1133 |
+
aio_status_line = (
|
| 1134 |
+
f"**AIO transformer version:** `{AIO_VERSION}` "
|
| 1135 |
+
f"({AIO_VERSION_SOURCE}; env `AIO_VERSION`={_AIO_ENV_RAW!r})"
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
with gr.Blocks() as demo:
|
| 1139 |
+
with gr.Column(elem_id="col-container"):
|
| 1140 |
+
gr.Markdown("# **Qwen-Image-Edit-2511-LoRAs-Fast**", elem_id="main-title")
|
| 1141 |
+
gr.Markdown(
|
| 1142 |
+
f"""This **experimental** space for [QIE-2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) utilizes [extracted transformers](https://huggingface.co/Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO) of [Phr00t’s Rapid AIO merge](https://huggingface.co/Phr00t/Qwen-Image-Edit-Rapid-AIO) and FA3-optimization with [LoRA](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image-Edit-2511) support and a couple of extra features:
|
| 1143 |
+
|
| 1144 |
+
- Optional conditioning-only routing for extra reference latents
|
| 1145 |
+
- Uncapped canvas resolution
|
| 1146 |
+
- Optional VAE tiling for high resolutions
|
| 1147 |
+
- Optional depth mapping for conditioning
|
| 1148 |
+
- Optional routing of output to input for further iterations
|
| 1149 |
+
- Optional alternative decoder [VAE](https://huggingface.co/spacepxl/Wan2.1-VAE-upscale2x/tree/main/diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1)
|
| 1150 |
+
|
| 1151 |
+
Current environment is running **{AIO_VERSION}** of the Rapid AIO. Duplicate the space and set the **AIO_VERSION** space variable to use a different version."""
|
| 1152 |
+
)
|
| 1153 |
+
gr.Markdown(aio_status_line)
|
| 1154 |
+
|
| 1155 |
+
with gr.Row(equal_height=True):
|
| 1156 |
+
with gr.Column():
|
| 1157 |
+
input_image_1 = gr.Image(label="Upload Image 1 (Base / Target)", type="pil", )
|
| 1158 |
+
|
| 1159 |
+
input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
|
| 1160 |
+
|
| 1161 |
+
with gr.Column(visible=False) as camera_container:
|
| 1162 |
+
gr.Markdown("### 🎮 3D Camera Control\n*Drag handles: 🟢 Azimuth, 🩷 Elevation, 🟠 Distance*")
|
| 1163 |
+
camera_3d = CameraControl3D(value={"azimuth": 0, "elevation": 0, "distance": 1.0}, elem_id="camera-3d-control")
|
| 1164 |
+
gr.Markdown("### 🎚️ Slider Controls")
|
| 1165 |
+
azimuth_slider = gr.Slider(label="Azimuth", minimum=0, maximum=315, step=45, value=0, info="0°=front, 90°=right, 180°=back, 270°=left")
|
| 1166 |
+
elevation_slider = gr.Slider(label="Elevation", minimum=-30, maximum=60, step=30, value=0, info="-30°=low angle, 0°=eye, 60°=high angle")
|
| 1167 |
+
distance_slider = gr.Slider(label="Distance", minimum=0.6, maximum=1.4, step=0.4, value=1.0, info="0.6=close, 1.0=medium, 1.4=wide")
|
| 1168 |
+
|
| 1169 |
+
|
| 1170 |
+
input_images_extra = gr.Gallery(
|
| 1171 |
+
label="Upload Additional Images (auto-indexed after Image 1/2)",
|
| 1172 |
+
type="pil",
|
| 1173 |
+
height=290,
|
| 1174 |
+
columns=4,
|
| 1175 |
+
rows=2,
|
| 1176 |
+
interactive=True,
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
prompt = gr.Text(
|
| 1180 |
+
label="Edit Prompt",
|
| 1181 |
+
show_label=True,
|
| 1182 |
+
placeholder="e.g., transform into photo..",
|
| 1183 |
+
)
|
| 1184 |
+
|
| 1185 |
+
run_button = gr.Button("Edit Image", variant="primary")
|
| 1186 |
+
|
| 1187 |
+
with gr.Column():
|
| 1188 |
+
output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
|
| 1189 |
+
|
| 1190 |
+
last_output = gr.State(value=None)
|
| 1191 |
+
|
| 1192 |
+
with gr.Row():
|
| 1193 |
+
btn_out_to_img1 = gr.Button("⬅️ Output → Image 1", variant="secondary")
|
| 1194 |
+
btn_out_to_img2 = gr.Button("⬅️ Output → Image 2", variant="secondary")
|
| 1195 |
+
btn_out_to_extra = gr.Button("➕ Output → Extra Ref", variant="secondary")
|
| 1196 |
+
|
| 1197 |
+
derived_preview = gr.Image(
|
| 1198 |
+
label="Derived Conditioning Preview",
|
| 1199 |
+
interactive=False,
|
| 1200 |
+
format="png",
|
| 1201 |
+
height=200,
|
| 1202 |
+
visible=False,
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
with gr.Row():
|
| 1206 |
+
lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
|
| 1207 |
+
lora_adapter = gr.Dropdown(
|
| 1208 |
+
label="Choose Editing Style",
|
| 1209 |
+
choices=lora_choices,
|
| 1210 |
+
value=NONE_LORA,
|
| 1211 |
+
)
|
| 1212 |
+
|
| 1213 |
+
with gr.Accordion("Advanced Settings", open=False, visible=True):
|
| 1214 |
+
with gr.Accordion("Derived Conditioning (Pose / Depth)", open=False):
|
| 1215 |
+
derived_type = gr.Dropdown(
|
| 1216 |
+
label="Derived Type (from Image 1)",
|
| 1217 |
+
choices=["None", "Depth (Depth Anything V2 Small)"],
|
| 1218 |
+
value="None",
|
| 1219 |
+
)
|
| 1220 |
+
derived_use_gpu = gr.Checkbox(label="Use GPU for derived model", value=False)
|
| 1221 |
+
add_derived_btn = gr.Button("➕ Add derived ref to Extras (conditioning-only recommended)")
|
| 1222 |
+
|
| 1223 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
| 1224 |
+
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
| 1225 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
|
| 1226 |
+
steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
|
| 1227 |
+
target_megapixels = gr.Slider(
|
| 1228 |
+
label="Target Megapixels (canvas, 0 = match input area)",
|
| 1229 |
+
minimum=0.0,
|
| 1230 |
+
maximum=6.0,
|
| 1231 |
+
step=0.1,
|
| 1232 |
+
value=1.0,
|
| 1233 |
+
)
|
| 1234 |
+
resolution_multiple = gr.Dropdown(
|
| 1235 |
+
label="Resolution lattice multiple (anti-drift)",
|
| 1236 |
+
choices=[32, 56, 112],
|
| 1237 |
+
value=32,
|
| 1238 |
+
interactive=True,
|
| 1239 |
+
)
|
| 1240 |
+
vae_ref_megapixels = gr.Slider(
|
| 1241 |
+
label="Extra refs VAE megapixels override (0 = use canvas)",
|
| 1242 |
+
minimum=0.0,
|
| 1243 |
+
maximum=6.0,
|
| 1244 |
+
step=0.1,
|
| 1245 |
+
value=0.0,
|
| 1246 |
+
)
|
| 1247 |
+
decoder_vae = gr.Dropdown(
|
| 1248 |
+
label="Decoder VAE",
|
| 1249 |
+
choices=["qwen", "wan2x"],
|
| 1250 |
+
value="qwen",
|
| 1251 |
+
interactive=True,
|
| 1252 |
+
)
|
| 1253 |
+
keep_decoder_2x = gr.Checkbox(
|
| 1254 |
+
label="Keep 2× output (wan2x only)",
|
| 1255 |
+
value=False,
|
| 1256 |
+
)
|
| 1257 |
+
extras_condition_only = gr.Checkbox(
|
| 1258 |
+
label="Extra references are conditioning-only (exclude from VAE)",
|
| 1259 |
+
value=True,
|
| 1260 |
+
)
|
| 1261 |
+
pad_to_canvas = gr.Checkbox(
|
| 1262 |
+
label="Pad images to canvas aspect (avoid warping)",
|
| 1263 |
+
value=True,
|
| 1264 |
+
)
|
| 1265 |
+
|
| 1266 |
+
# ✅ NEW: VAE tiling toggle (OFF by default)
|
| 1267 |
+
vae_tiling = gr.Checkbox(
|
| 1268 |
+
label="VAE tiling (lower VRAM, slower)",
|
| 1269 |
+
value=False,
|
| 1270 |
+
)
|
| 1271 |
+
|
| 1272 |
+
# On LoRA selection: preset prompt + toggle Image 2
|
| 1273 |
+
lora_adapter.change(
|
| 1274 |
+
fn=on_lora_change_ui,
|
| 1275 |
+
inputs=[lora_adapter, prompt, extras_condition_only],
|
| 1276 |
+
outputs=[prompt, input_image_2, extras_condition_only, camera_container],
|
| 1277 |
+
)
|
| 1278 |
+
|
| 1279 |
+
# Examples removed automatically by setup_manager
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
# --- 3D Camera Events ---
|
| 1283 |
+
def update_prompt_from_sliders(az, el, dist, curr_prompt):
|
| 1284 |
+
return update_prompt_with_camera(az, el, dist, curr_prompt)
|
| 1285 |
+
|
| 1286 |
+
def sync_3d_to_sliders(cv, curr_prompt):
|
| 1287 |
+
if cv and isinstance(cv, dict):
|
| 1288 |
+
az = cv.get('azimuth', 0)
|
| 1289 |
+
el = cv.get('elevation', 0)
|
| 1290 |
+
dist = cv.get('distance', 1.0)
|
| 1291 |
+
return az, el, dist, update_prompt_with_camera(az, el, dist, curr_prompt)
|
| 1292 |
+
return gr.update(), gr.update(), gr.update(), gr.update()
|
| 1293 |
+
|
| 1294 |
+
def sync_sliders_to_3d(az, el, dist):
|
| 1295 |
+
return {"azimuth": az, "elevation": el, "distance": dist}
|
| 1296 |
+
|
| 1297 |
+
|
| 1298 |
+
def update_3d_image(img):
|
| 1299 |
+
if img is None: return gr.update(imageUrl=None)
|
| 1300 |
+
import base64
|
| 1301 |
+
from io import BytesIO
|
| 1302 |
+
buf = BytesIO()
|
| 1303 |
+
img.save(buf, format="PNG")
|
| 1304 |
+
durl = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
|
| 1305 |
+
return gr.update(imageUrl=durl)
|
| 1306 |
+
|
| 1307 |
+
for slider in [azimuth_slider, elevation_slider, distance_slider]:
|
| 1308 |
+
slider.change(fn=update_prompt_from_sliders, inputs=[azimuth_slider, elevation_slider, distance_slider, prompt], outputs=[prompt])
|
| 1309 |
+
slider.release(fn=sync_sliders_to_3d, inputs=[azimuth_slider, elevation_slider, distance_slider], outputs=[camera_3d])
|
| 1310 |
+
|
| 1311 |
+
camera_3d.change(fn=sync_3d_to_sliders, inputs=[camera_3d, prompt], outputs=[azimuth_slider, elevation_slider, distance_slider, prompt])
|
| 1312 |
+
|
| 1313 |
+
input_image_1.upload(fn=update_3d_image, inputs=[input_image_1], outputs=[camera_3d])
|
| 1314 |
+
input_image_1.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])
|
| 1315 |
+
|
| 1316 |
+
run_button.click(
|
| 1317 |
+
fn=infer,
|
| 1318 |
+
inputs=[
|
| 1319 |
+
input_image_1,
|
| 1320 |
+
input_image_2,
|
| 1321 |
+
input_images_extra,
|
| 1322 |
+
prompt,
|
| 1323 |
+
lora_adapter,
|
| 1324 |
+
seed,
|
| 1325 |
+
randomize_seed,
|
| 1326 |
+
guidance_scale,
|
| 1327 |
+
steps,
|
| 1328 |
+
target_megapixels,
|
| 1329 |
+
extras_condition_only,
|
| 1330 |
+
pad_to_canvas,
|
| 1331 |
+
vae_tiling,
|
| 1332 |
+
resolution_multiple,
|
| 1333 |
+
vae_ref_megapixels,
|
| 1334 |
+
decoder_vae,
|
| 1335 |
+
keep_decoder_2x,
|
| 1336 |
+
],
|
| 1337 |
+
outputs=[output_image, seed, last_output],
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
+
# Output routing buttons
|
| 1341 |
+
btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
|
| 1342 |
+
btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
|
| 1343 |
+
btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
|
| 1344 |
+
|
| 1345 |
+
# Derived conditioning: append pose/depth map as extra ref (UI shows preview)
|
| 1346 |
+
add_derived_btn.click(
|
| 1347 |
+
fn=add_derived_ref,
|
| 1348 |
+
inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu],
|
| 1349 |
+
outputs=[input_images_extra, derived_preview],
|
| 1350 |
+
)
|
| 1351 |
+
|
| 1352 |
+
if __name__ == "__main__":
|
| 1353 |
+
head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
|
| 1354 |
+
demo.queue(max_size=30).launch(head=head, server_name="0.0.0.0", share=True,
|
| 1355 |
+
css=css,
|
| 1356 |
+
theme=orange_red_theme,
|
| 1357 |
+
mcp_server=True,
|
| 1358 |
+
ssr_mode=False,
|
| 1359 |
+
show_error=True,
|
| 1360 |
+
)
|
| 1361 |
+
|
| 1362 |
+
# Manual Patch for missing prompts
|
| 1363 |
+
try:
|
| 1364 |
+
LORA_PRESET_PROMPTS.update({
|
| 1365 |
+
"Consistance": "improve consistency and quality of the generated image",
|
| 1366 |
+
"F2P": "transform the image into a high-quality photo with realistic details",
|
| 1367 |
+
"Multiple-Angles": "change the camera angle of the image",
|
| 1368 |
+
"Light-Restoration": "Remove shadows and relight the image using soft lighting",
|
| 1369 |
+
"Relight": "Relight the image with cinematic lighting",
|
| 1370 |
+
"Multi-Angle-Lighting": "Change the lighting direction and intensity",
|
| 1371 |
+
"Edit-Skin": "Enhance skin textures and natural details",
|
| 1372 |
+
"Next-Scene": "Generate the next scene based on the current image",
|
| 1373 |
+
"Flat-Log": "Desaturate and lower contrast for a flat log look",
|
| 1374 |
+
"Upscale-Image": "Enhance and sharpen the image details",
|
| 1375 |
+
"BFS-Best-FaceSwap": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 1376 |
+
"BFS-Best-FaceSwap-merge": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 1377 |
+
"Qwen-lora-nsfw": "Convert this picture to artistic style.",
|
| 1378 |
+
})
|
| 1379 |
+
except NameError:
|
| 1380 |
+
pass
|
camera_control_ui.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
# Azimuth mappings (8 positions)
|
| 4 |
+
AZIMUTH_MAP = {
|
| 5 |
+
0: "front view",
|
| 6 |
+
45: "front-right quarter view",
|
| 7 |
+
90: "right side view",
|
| 8 |
+
135: "back-right quarter view",
|
| 9 |
+
180: "back view",
|
| 10 |
+
225: "back-left quarter view",
|
| 11 |
+
270: "left side view",
|
| 12 |
+
315: "front-left quarter view"
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
# Elevation mappings (4 positions)
|
| 16 |
+
ELEVATION_MAP = {
|
| 17 |
+
-30: "low-angle shot",
|
| 18 |
+
0: "eye-level shot",
|
| 19 |
+
30: "elevated shot",
|
| 20 |
+
60: "high-angle shot"
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
# Distance mappings (3 positions)
|
| 24 |
+
DISTANCE_MAP = {
|
| 25 |
+
0.6: "close-up",
|
| 26 |
+
1.0: "medium shot",
|
| 27 |
+
1.8: "wide shot"
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def snap_to_nearest(value, options):
|
| 32 |
+
"""Snap a value to the nearest option in a list."""
|
| 33 |
+
return min(options, key=lambda x: abs(x - value))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_camera_prompt(azimuth: float, elevation: float, distance: float) -> str:
|
| 37 |
+
"""
|
| 38 |
+
Build a camera prompt from azimuth, elevation, and distance values.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
azimuth: Horizontal rotation in degrees (0-360)
|
| 42 |
+
elevation: Vertical angle in degrees (-30 to 60)
|
| 43 |
+
distance: Distance factor (0.6 to 1.8)
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Formatted prompt string for the LoRA
|
| 47 |
+
"""
|
| 48 |
+
# Snap to nearest valid values
|
| 49 |
+
azimuth_snapped = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
|
| 50 |
+
elevation_snapped = snap_to_nearest(elevation, list(ELEVATION_MAP.keys()))
|
| 51 |
+
distance_snapped = snap_to_nearest(distance, list(DISTANCE_MAP.keys()))
|
| 52 |
+
|
| 53 |
+
azimuth_name = AZIMUTH_MAP[azimuth_snapped]
|
| 54 |
+
elevation_name = ELEVATION_MAP[elevation_snapped]
|
| 55 |
+
distance_name = DISTANCE_MAP[distance_snapped]
|
| 56 |
+
|
| 57 |
+
return f"<sks> {azimuth_name} {elevation_name} {distance_name}"
|
| 58 |
+
|
| 59 |
+
def update_prompt_with_camera(azimuth: float, elevation: float, distance: float, current_prompt: str) -> str:
|
| 60 |
+
"""
|
| 61 |
+
Updates the existing prompt by replacing or appending the camera trigger words.
|
| 62 |
+
"""
|
| 63 |
+
import re
|
| 64 |
+
camera_str = build_camera_prompt(azimuth, elevation, distance)
|
| 65 |
+
|
| 66 |
+
if not current_prompt:
|
| 67 |
+
return camera_str
|
| 68 |
+
|
| 69 |
+
# Remove any existing <sks> ... shot tags
|
| 70 |
+
# The pattern matches <sks> followed by any characters until the word "shot"
|
| 71 |
+
clean_prompt = re.sub(r"<sks>.*?shot(?!.*shot)", "", current_prompt).strip()
|
| 72 |
+
|
| 73 |
+
# Clean up multiple spaces
|
| 74 |
+
clean_prompt = re.sub(r"\s+", " ", clean_prompt)
|
| 75 |
+
|
| 76 |
+
if clean_prompt:
|
| 77 |
+
return f"{clean_prompt} {camera_str}"
|
| 78 |
+
return camera_str
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# --- 3D Camera Control Component ---
|
| 83 |
+
class CameraControl3D(gr.HTML):
|
| 84 |
+
"""
|
| 85 |
+
A 3D camera control component using Three.js.
|
| 86 |
+
Outputs: { azimuth: number, elevation: number, distance: number }
|
| 87 |
+
Accepts imageUrl prop to display user's uploaded image on the plane.
|
| 88 |
+
"""
|
| 89 |
+
def __init__(self, value=None, imageUrl=None, **kwargs):
|
| 90 |
+
if value is None:
|
| 91 |
+
value = {"azimuth": 0, "elevation": 0, "distance": 1.0}
|
| 92 |
+
|
| 93 |
+
html_template = """
|
| 94 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
|
| 95 |
+
<div id="camera-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
|
| 96 |
+
<div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
|
| 97 |
+
</div>
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
js_on_load = """
|
| 101 |
+
(() => {
|
| 102 |
+
const wrapper = element.querySelector('#camera-control-wrapper');
|
| 103 |
+
const promptOverlay = element.querySelector('#prompt-overlay');
|
| 104 |
+
|
| 105 |
+
// Wait for THREE to load
|
| 106 |
+
const initScene = () => {
|
| 107 |
+
if (typeof THREE === 'undefined') {
|
| 108 |
+
setTimeout(initScene, 100);
|
| 109 |
+
return;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// Scene setup
|
| 113 |
+
const scene = new THREE.Scene();
|
| 114 |
+
scene.background = new THREE.Color(0x1a1a1a);
|
| 115 |
+
|
| 116 |
+
const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
|
| 117 |
+
camera.position.set(4.5, 3, 4.5);
|
| 118 |
+
camera.lookAt(0, 0.75, 0);
|
| 119 |
+
|
| 120 |
+
const renderer = new THREE.WebGLRenderer({ antialias: true });
|
| 121 |
+
renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
|
| 122 |
+
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
|
| 123 |
+
wrapper.insertBefore(renderer.domElement, promptOverlay);
|
| 124 |
+
|
| 125 |
+
// Lighting
|
| 126 |
+
scene.add(new THREE.AmbientLight(0xffffff, 0.6));
|
| 127 |
+
const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
|
| 128 |
+
dirLight.position.set(5, 10, 5);
|
| 129 |
+
scene.add(dirLight);
|
| 130 |
+
|
| 131 |
+
// Grid
|
| 132 |
+
scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
|
| 133 |
+
|
| 134 |
+
// Constants - reduced distances for tighter framing
|
| 135 |
+
const CENTER = new THREE.Vector3(0, 0.75, 0);
|
| 136 |
+
const BASE_DISTANCE = 1.6;
|
| 137 |
+
const AZIMUTH_RADIUS = 2.4;
|
| 138 |
+
const ELEVATION_RADIUS = 1.8;
|
| 139 |
+
|
| 140 |
+
// State
|
| 141 |
+
let azimuthAngle = props.value?.azimuth || 0;
|
| 142 |
+
let elevationAngle = props.value?.elevation || 0;
|
| 143 |
+
let distanceFactor = props.value?.distance || 1.0;
|
| 144 |
+
|
| 145 |
+
// Mappings - reduced wide shot multiplier
|
| 146 |
+
const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
|
| 147 |
+
const elevationSteps = [-30, 0, 30, 60];
|
| 148 |
+
const distanceSteps = [0.6, 1.0, 1.4];
|
| 149 |
+
|
| 150 |
+
const azimuthNames = {
|
| 151 |
+
0: 'front view', 45: 'front-right quarter view', 90: 'right side view',
|
| 152 |
+
135: 'back-right quarter view', 180: 'back view', 225: 'back-left quarter view',
|
| 153 |
+
270: 'left side view', 315: 'front-left quarter view'
|
| 154 |
+
};
|
| 155 |
+
const elevationNames = { '-30': 'low-angle shot', '0': 'eye-level shot', '30': 'elevated shot', '60': 'high-angle shot' };
|
| 156 |
+
const distanceNames = { '0.6': 'close-up', '1': 'medium shot', '1.4': 'wide shot' };
|
| 157 |
+
|
| 158 |
+
function snapToNearest(value, steps) {
|
| 159 |
+
return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Create placeholder texture (smiley face)
|
| 163 |
+
function createPlaceholderTexture() {
|
| 164 |
+
const canvas = document.createElement('canvas');
|
| 165 |
+
canvas.width = 256;
|
| 166 |
+
canvas.height = 256;
|
| 167 |
+
const ctx = canvas.getContext('2d');
|
| 168 |
+
ctx.fillStyle = '#3a3a4a';
|
| 169 |
+
ctx.fillRect(0, 0, 256, 256);
|
| 170 |
+
ctx.fillStyle = '#ffcc99';
|
| 171 |
+
ctx.beginPath();
|
| 172 |
+
ctx.arc(128, 128, 80, 0, Math.PI * 2);
|
| 173 |
+
ctx.fill();
|
| 174 |
+
ctx.fillStyle = '#333';
|
| 175 |
+
ctx.beginPath();
|
| 176 |
+
ctx.arc(100, 110, 10, 0, Math.PI * 2);
|
| 177 |
+
ctx.arc(156, 110, 10, 0, Math.PI * 2);
|
| 178 |
+
ctx.fill();
|
| 179 |
+
ctx.strokeStyle = '#333';
|
| 180 |
+
ctx.lineWidth = 3;
|
| 181 |
+
ctx.beginPath();
|
| 182 |
+
ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
|
| 183 |
+
ctx.stroke();
|
| 184 |
+
return new THREE.CanvasTexture(canvas);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// Target image plane
|
| 188 |
+
let currentTexture = createPlaceholderTexture();
|
| 189 |
+
const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
|
| 190 |
+
let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
|
| 191 |
+
targetPlane.position.copy(CENTER);
|
| 192 |
+
scene.add(targetPlane);
|
| 193 |
+
|
| 194 |
+
// Function to update texture from image URL
|
| 195 |
+
function updateTextureFromUrl(url) {
|
| 196 |
+
if (!url) {
|
| 197 |
+
// Reset to placeholder
|
| 198 |
+
planeMaterial.map = createPlaceholderTexture();
|
| 199 |
+
planeMaterial.needsUpdate = true;
|
| 200 |
+
// Reset plane to square
|
| 201 |
+
scene.remove(targetPlane);
|
| 202 |
+
targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
|
| 203 |
+
targetPlane.position.copy(CENTER);
|
| 204 |
+
scene.add(targetPlane);
|
| 205 |
+
return;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
const loader = new THREE.TextureLoader();
|
| 209 |
+
loader.crossOrigin = 'anonymous';
|
| 210 |
+
loader.load(url, (texture) => {
|
| 211 |
+
texture.minFilter = THREE.LinearFilter;
|
| 212 |
+
texture.magFilter = THREE.LinearFilter;
|
| 213 |
+
planeMaterial.map = texture;
|
| 214 |
+
planeMaterial.needsUpdate = true;
|
| 215 |
+
|
| 216 |
+
// Adjust plane aspect ratio to match image
|
| 217 |
+
const img = texture.image;
|
| 218 |
+
if (img && img.width && img.height) {
|
| 219 |
+
const aspect = img.width / img.height;
|
| 220 |
+
const maxSize = 1.5;
|
| 221 |
+
let planeWidth, planeHeight;
|
| 222 |
+
if (aspect > 1) {
|
| 223 |
+
planeWidth = maxSize;
|
| 224 |
+
planeHeight = maxSize / aspect;
|
| 225 |
+
} else {
|
| 226 |
+
planeHeight = maxSize;
|
| 227 |
+
planeWidth = maxSize * aspect;
|
| 228 |
+
}
|
| 229 |
+
scene.remove(targetPlane);
|
| 230 |
+
targetPlane = new THREE.Mesh(
|
| 231 |
+
new THREE.PlaneGeometry(planeWidth, planeHeight),
|
| 232 |
+
planeMaterial
|
| 233 |
+
);
|
| 234 |
+
targetPlane.position.copy(CENTER);
|
| 235 |
+
scene.add(targetPlane);
|
| 236 |
+
}
|
| 237 |
+
}, undefined, (err) => {
|
| 238 |
+
console.error('Failed to load texture:', err);
|
| 239 |
+
});
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
// Check for initial imageUrl
|
| 243 |
+
if (props.imageUrl) {
|
| 244 |
+
updateTextureFromUrl(props.imageUrl);
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// Camera model
|
| 248 |
+
const cameraGroup = new THREE.Group();
|
| 249 |
+
const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
|
| 250 |
+
const body = new THREE.Mesh(new THREE.BoxGeometry(0.3, 0.22, 0.38), bodyMat);
|
| 251 |
+
cameraGroup.add(body);
|
| 252 |
+
const lens = new THREE.Mesh(
|
| 253 |
+
new THREE.CylinderGeometry(0.09, 0.11, 0.18, 16),
|
| 254 |
+
new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
|
| 255 |
+
);
|
| 256 |
+
lens.rotation.x = Math.PI / 2;
|
| 257 |
+
lens.position.z = 0.26;
|
| 258 |
+
cameraGroup.add(lens);
|
| 259 |
+
scene.add(cameraGroup);
|
| 260 |
+
|
| 261 |
+
// GREEN: Azimuth ring
|
| 262 |
+
const azimuthRing = new THREE.Mesh(
|
| 263 |
+
new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
|
| 264 |
+
new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
|
| 265 |
+
);
|
| 266 |
+
azimuthRing.rotation.x = Math.PI / 2;
|
| 267 |
+
azimuthRing.position.y = 0.05;
|
| 268 |
+
scene.add(azimuthRing);
|
| 269 |
+
|
| 270 |
+
const azimuthHandle = new THREE.Mesh(
|
| 271 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 272 |
+
new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
|
| 273 |
+
);
|
| 274 |
+
azimuthHandle.userData.type = 'azimuth';
|
| 275 |
+
scene.add(azimuthHandle);
|
| 276 |
+
|
| 277 |
+
// PINK: Elevation arc
|
| 278 |
+
const arcPoints = [];
|
| 279 |
+
for (let i = 0; i <= 32; i++) {
|
| 280 |
+
const angle = THREE.MathUtils.degToRad(-30 + (90 * i / 32));
|
| 281 |
+
arcPoints.push(new THREE.Vector3(-0.8, ELEVATION_RADIUS * Math.sin(angle) + CENTER.y, ELEVATION_RADIUS * Math.cos(angle)));
|
| 282 |
+
}
|
| 283 |
+
const arcCurve = new THREE.CatmullRomCurve3(arcPoints);
|
| 284 |
+
const elevationArc = new THREE.Mesh(
|
| 285 |
+
new THREE.TubeGeometry(arcCurve, 32, 0.04, 8, false),
|
| 286 |
+
new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
|
| 287 |
+
);
|
| 288 |
+
scene.add(elevationArc);
|
| 289 |
+
|
| 290 |
+
const elevationHandle = new THREE.Mesh(
|
| 291 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 292 |
+
new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
|
| 293 |
+
);
|
| 294 |
+
elevationHandle.userData.type = 'elevation';
|
| 295 |
+
scene.add(elevationHandle);
|
| 296 |
+
|
| 297 |
+
// ORANGE: Distance line & handle
|
| 298 |
+
const distanceLineGeo = new THREE.BufferGeometry();
|
| 299 |
+
const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
|
| 300 |
+
scene.add(distanceLine);
|
| 301 |
+
|
| 302 |
+
const distanceHandle = new THREE.Mesh(
|
| 303 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 304 |
+
new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
|
| 305 |
+
);
|
| 306 |
+
distanceHandle.userData.type = 'distance';
|
| 307 |
+
scene.add(distanceHandle);
|
| 308 |
+
|
| 309 |
+
function updatePositions() {
|
| 310 |
+
const distance = BASE_DISTANCE * distanceFactor;
|
| 311 |
+
const azRad = THREE.MathUtils.degToRad(azimuthAngle);
|
| 312 |
+
const elRad = THREE.MathUtils.degToRad(elevationAngle);
|
| 313 |
+
|
| 314 |
+
const camX = distance * Math.sin(azRad) * Math.cos(elRad);
|
| 315 |
+
const camY = distance * Math.sin(elRad) + CENTER.y;
|
| 316 |
+
const camZ = distance * Math.cos(azRad) * Math.cos(elRad);
|
| 317 |
+
|
| 318 |
+
cameraGroup.position.set(camX, camY, camZ);
|
| 319 |
+
cameraGroup.lookAt(CENTER);
|
| 320 |
+
|
| 321 |
+
azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
|
| 322 |
+
elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));
|
| 323 |
+
|
| 324 |
+
const orangeDist = distance - 0.5;
|
| 325 |
+
distanceHandle.position.set(
|
| 326 |
+
orangeDist * Math.sin(azRad) * Math.cos(elRad),
|
| 327 |
+
orangeDist * Math.sin(elRad) + CENTER.y,
|
| 328 |
+
orangeDist * Math.cos(azRad) * Math.cos(elRad)
|
| 329 |
+
);
|
| 330 |
+
distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);
|
| 331 |
+
|
| 332 |
+
// Update prompt
|
| 333 |
+
const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
|
| 334 |
+
const elSnap = snapToNearest(elevationAngle, elevationSteps);
|
| 335 |
+
const distSnap = snapToNearest(distanceFactor, distanceSteps);
|
| 336 |
+
const distKey = distSnap === 1 ? '1' : distSnap.toFixed(1);
|
| 337 |
+
const prompt = '<sks> ' + azimuthNames[azSnap] + ' ' + elevationNames[String(elSnap)] + ' ' + distanceNames[distKey];
|
| 338 |
+
promptOverlay.textContent = prompt;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
function updatePropsAndTrigger() {
|
| 342 |
+
const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
|
| 343 |
+
const elSnap = snapToNearest(elevationAngle, elevationSteps);
|
| 344 |
+
const distSnap = snapToNearest(distanceFactor, distanceSteps);
|
| 345 |
+
|
| 346 |
+
props.value = { azimuth: azSnap, elevation: elSnap, distance: distSnap };
|
| 347 |
+
trigger('change', props.value);
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
// Raycasting
|
| 351 |
+
const raycaster = new THREE.Raycaster();
|
| 352 |
+
const mouse = new THREE.Vector2();
|
| 353 |
+
let isDragging = false;
|
| 354 |
+
let dragTarget = null;
|
| 355 |
+
let dragStartMouse = new THREE.Vector2();
|
| 356 |
+
let dragStartDistance = 1.0;
|
| 357 |
+
const intersection = new THREE.Vector3();
|
| 358 |
+
|
| 359 |
+
const canvas = renderer.domElement;
|
| 360 |
+
|
| 361 |
+
canvas.addEventListener('mousedown', (e) => {
|
| 362 |
+
const rect = canvas.getBoundingClientRect();
|
| 363 |
+
mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
|
| 364 |
+
mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
|
| 365 |
+
|
| 366 |
+
raycaster.setFromCamera(mouse, camera);
|
| 367 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 368 |
+
|
| 369 |
+
if (intersects.length > 0) {
|
| 370 |
+
isDragging = true;
|
| 371 |
+
dragTarget = intersects[0].object;
|
| 372 |
+
dragTarget.material.emissiveIntensity = 1.0;
|
| 373 |
+
dragTarget.scale.setScalar(1.3);
|
| 374 |
+
dragStartMouse.copy(mouse);
|
| 375 |
+
dragStartDistance = distanceFactor;
|
| 376 |
+
canvas.style.cursor = 'grabbing';
|
| 377 |
+
}
|
| 378 |
+
});
|
| 379 |
+
|
| 380 |
+
canvas.addEventListener('mousemove', (e) => {
|
| 381 |
+
const rect = canvas.getBoundingClientRect();
|
| 382 |
+
mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
|
| 383 |
+
mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
|
| 384 |
+
|
| 385 |
+
if (isDragging && dragTarget) {
|
| 386 |
+
raycaster.setFromCamera(mouse, camera);
|
| 387 |
+
|
| 388 |
+
if (dragTarget.userData.type === 'azimuth') {
|
| 389 |
+
const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
|
| 390 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 391 |
+
azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
|
| 392 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 393 |
+
}
|
| 394 |
+
} else if (dragTarget.userData.type === 'elevation') {
|
| 395 |
+
const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
|
| 396 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 397 |
+
const relY = intersection.y - CENTER.y;
|
| 398 |
+
const relZ = intersection.z;
|
| 399 |
+
elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
|
| 400 |
+
}
|
| 401 |
+
} else if (dragTarget.userData.type === 'distance') {
|
| 402 |
+
const deltaY = mouse.y - dragStartMouse.y;
|
| 403 |
+
distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
|
| 404 |
+
}
|
| 405 |
+
updatePositions();
|
| 406 |
+
} else {
|
| 407 |
+
raycaster.setFromCamera(mouse, camera);
|
| 408 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 409 |
+
[azimuthHandle, elevationHandle, distanceHandle].forEach(h => {
|
| 410 |
+
h.material.emissiveIntensity = 0.5;
|
| 411 |
+
h.scale.setScalar(1);
|
| 412 |
+
});
|
| 413 |
+
if (intersects.length > 0) {
|
| 414 |
+
intersects[0].object.material.emissiveIntensity = 0.8;
|
| 415 |
+
intersects[0].object.scale.setScalar(1.1);
|
| 416 |
+
canvas.style.cursor = 'grab';
|
| 417 |
+
} else {
|
| 418 |
+
canvas.style.cursor = 'default';
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
});
|
| 422 |
+
|
| 423 |
+
const onMouseUp = () => {
|
| 424 |
+
if (dragTarget) {
|
| 425 |
+
dragTarget.material.emissiveIntensity = 0.5;
|
| 426 |
+
dragTarget.scale.setScalar(1);
|
| 427 |
+
|
| 428 |
+
// Snap and animate
|
| 429 |
+
const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
|
| 430 |
+
const targetEl = snapToNearest(elevationAngle, elevationSteps);
|
| 431 |
+
const targetDist = snapToNearest(distanceFactor, distanceSteps);
|
| 432 |
+
|
| 433 |
+
const startAz = azimuthAngle, startEl = elevationAngle, startDist = distanceFactor;
|
| 434 |
+
const startTime = Date.now();
|
| 435 |
+
|
| 436 |
+
function animateSnap() {
|
| 437 |
+
const t = Math.min((Date.now() - startTime) / 200, 1);
|
| 438 |
+
const ease = 1 - Math.pow(1 - t, 3);
|
| 439 |
+
|
| 440 |
+
let azDiff = targetAz - startAz;
|
| 441 |
+
if (azDiff > 180) azDiff -= 360;
|
| 442 |
+
if (azDiff < -180) azDiff += 360;
|
| 443 |
+
azimuthAngle = startAz + azDiff * ease;
|
| 444 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 445 |
+
if (azimuthAngle >= 360) azimuthAngle -= 360;
|
| 446 |
+
|
| 447 |
+
elevationAngle = startEl + (targetEl - startEl) * ease;
|
| 448 |
+
distanceFactor = startDist + (targetDist - startDist) * ease;
|
| 449 |
+
|
| 450 |
+
updatePositions();
|
| 451 |
+
if (t < 1) requestAnimationFrame(animateSnap);
|
| 452 |
+
else updatePropsAndTrigger();
|
| 453 |
+
}
|
| 454 |
+
animateSnap();
|
| 455 |
+
}
|
| 456 |
+
isDragging = false;
|
| 457 |
+
dragTarget = null;
|
| 458 |
+
canvas.style.cursor = 'default';
|
| 459 |
+
};
|
| 460 |
+
|
| 461 |
+
canvas.addEventListener('mouseup', onMouseUp);
|
| 462 |
+
canvas.addEventListener('mouseleave', onMouseUp);
|
| 463 |
+
|
| 464 |
+
// Touch support for mobile
|
| 465 |
+
canvas.addEventListener('touchstart', (e) => {
|
| 466 |
+
e.preventDefault();
|
| 467 |
+
const touch = e.touches[0];
|
| 468 |
+
const rect = canvas.getBoundingClientRect();
|
| 469 |
+
mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
|
| 470 |
+
mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
|
| 471 |
+
|
| 472 |
+
raycaster.setFromCamera(mouse, camera);
|
| 473 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 474 |
+
|
| 475 |
+
if (intersects.length > 0) {
|
| 476 |
+
isDragging = true;
|
| 477 |
+
dragTarget = intersects[0].object;
|
| 478 |
+
dragTarget.material.emissiveIntensity = 1.0;
|
| 479 |
+
dragTarget.scale.setScalar(1.3);
|
| 480 |
+
dragStartMouse.copy(mouse);
|
| 481 |
+
dragStartDistance = distanceFactor;
|
| 482 |
+
}
|
| 483 |
+
}, { passive: false });
|
| 484 |
+
|
| 485 |
+
canvas.addEventListener('touchmove', (e) => {
|
| 486 |
+
e.preventDefault();
|
| 487 |
+
const touch = e.touches[0];
|
| 488 |
+
const rect = canvas.getBoundingClientRect();
|
| 489 |
+
mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
|
| 490 |
+
mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
|
| 491 |
+
|
| 492 |
+
if (isDragging && dragTarget) {
|
| 493 |
+
raycaster.setFromCamera(mouse, camera);
|
| 494 |
+
|
| 495 |
+
if (dragTarget.userData.type === 'azimuth') {
|
| 496 |
+
const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
|
| 497 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 498 |
+
azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
|
| 499 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 500 |
+
}
|
| 501 |
+
} else if (dragTarget.userData.type === 'elevation') {
|
| 502 |
+
const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
|
| 503 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 504 |
+
const relY = intersection.y - CENTER.y;
|
| 505 |
+
const relZ = intersection.z;
|
| 506 |
+
elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
|
| 507 |
+
}
|
| 508 |
+
} else if (dragTarget.userData.type === 'distance') {
|
| 509 |
+
const deltaY = mouse.y - dragStartMouse.y;
|
| 510 |
+
distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
|
| 511 |
+
}
|
| 512 |
+
updatePositions();
|
| 513 |
+
}
|
| 514 |
+
}, { passive: false });
|
| 515 |
+
|
| 516 |
+
canvas.addEventListener('touchend', (e) => {
|
| 517 |
+
e.preventDefault();
|
| 518 |
+
onMouseUp();
|
| 519 |
+
}, { passive: false });
|
| 520 |
+
|
| 521 |
+
canvas.addEventListener('touchcancel', (e) => {
|
| 522 |
+
e.preventDefault();
|
| 523 |
+
onMouseUp();
|
| 524 |
+
}, { passive: false });
|
| 525 |
+
|
| 526 |
+
// Initial update
|
| 527 |
+
updatePositions();
|
| 528 |
+
|
| 529 |
+
// Render loop
|
| 530 |
+
function render() {
|
| 531 |
+
requestAnimationFrame(render);
|
| 532 |
+
renderer.render(scene, camera);
|
| 533 |
+
}
|
| 534 |
+
render();
|
| 535 |
+
|
| 536 |
+
// Handle resize
|
| 537 |
+
new ResizeObserver(() => {
|
| 538 |
+
camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
|
| 539 |
+
camera.updateProjectionMatrix();
|
| 540 |
+
renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
|
| 541 |
+
}).observe(wrapper);
|
| 542 |
+
|
| 543 |
+
// Store update functions for external calls
|
| 544 |
+
wrapper._updateFromProps = (newVal) => {
|
| 545 |
+
if (newVal && typeof newVal === 'object') {
|
| 546 |
+
azimuthAngle = newVal.azimuth ?? azimuthAngle;
|
| 547 |
+
elevationAngle = newVal.elevation ?? elevationAngle;
|
| 548 |
+
distanceFactor = newVal.distance ?? distanceFactor;
|
| 549 |
+
updatePositions();
|
| 550 |
+
}
|
| 551 |
+
};
|
| 552 |
+
|
| 553 |
+
wrapper._updateTexture = updateTextureFromUrl;
|
| 554 |
+
|
| 555 |
+
// Watch for prop changes (imageUrl and value)
|
| 556 |
+
let lastImageUrl = props.imageUrl;
|
| 557 |
+
let lastValue = JSON.stringify(props.value);
|
| 558 |
+
setInterval(() => {
|
| 559 |
+
// Check imageUrl changes
|
| 560 |
+
if (props.imageUrl !== lastImageUrl) {
|
| 561 |
+
lastImageUrl = props.imageUrl;
|
| 562 |
+
updateTextureFromUrl(props.imageUrl);
|
| 563 |
+
}
|
| 564 |
+
// Check value changes (from sliders)
|
| 565 |
+
const currentValue = JSON.stringify(props.value);
|
| 566 |
+
if (currentValue !== lastValue) {
|
| 567 |
+
lastValue = currentValue;
|
| 568 |
+
if (props.value && typeof props.value === 'object') {
|
| 569 |
+
azimuthAngle = props.value.azimuth ?? azimuthAngle;
|
| 570 |
+
elevationAngle = props.value.elevation ?? elevationAngle;
|
| 571 |
+
distanceFactor = props.value.distance ?? distanceFactor;
|
| 572 |
+
updatePositions();
|
| 573 |
+
}
|
| 574 |
+
}
|
| 575 |
+
}, 100);
|
| 576 |
+
};
|
| 577 |
+
|
| 578 |
+
initScene();
|
| 579 |
+
})();
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
super().__init__(
|
| 583 |
+
value=value,
|
| 584 |
+
html_template=html_template,
|
| 585 |
+
js_on_load=js_on_load,
|
| 586 |
+
imageUrl=imageUrl,
|
| 587 |
+
**kwargs
|
| 588 |
+
)
|
| 589 |
+
|
camera_control_ui.pyi
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
# Azimuth mappings (8 positions)
|
| 4 |
+
AZIMUTH_MAP = {
|
| 5 |
+
0: "front view",
|
| 6 |
+
45: "front-right quarter view",
|
| 7 |
+
90: "right side view",
|
| 8 |
+
135: "back-right quarter view",
|
| 9 |
+
180: "back view",
|
| 10 |
+
225: "back-left quarter view",
|
| 11 |
+
270: "left side view",
|
| 12 |
+
315: "front-left quarter view"
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
# Elevation mappings (4 positions)
|
| 16 |
+
ELEVATION_MAP = {
|
| 17 |
+
-30: "low-angle shot",
|
| 18 |
+
0: "eye-level shot",
|
| 19 |
+
30: "elevated shot",
|
| 20 |
+
60: "high-angle shot"
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
# Distance mappings (3 positions)
|
| 24 |
+
DISTANCE_MAP = {
|
| 25 |
+
0.6: "close-up",
|
| 26 |
+
1.0: "medium shot",
|
| 27 |
+
1.8: "wide shot"
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def snap_to_nearest(value, options):
|
| 32 |
+
"""Snap a value to the nearest option in a list."""
|
| 33 |
+
return min(options, key=lambda x: abs(x - value))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_camera_prompt(azimuth: float, elevation: float, distance: float) -> str:
|
| 37 |
+
"""
|
| 38 |
+
Build a camera prompt from azimuth, elevation, and distance values.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
azimuth: Horizontal rotation in degrees (0-360)
|
| 42 |
+
elevation: Vertical angle in degrees (-30 to 60)
|
| 43 |
+
distance: Distance factor (0.6 to 1.8)
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Formatted prompt string for the LoRA
|
| 47 |
+
"""
|
| 48 |
+
# Snap to nearest valid values
|
| 49 |
+
azimuth_snapped = snap_to_nearest(azimuth, list(AZIMUTH_MAP.keys()))
|
| 50 |
+
elevation_snapped = snap_to_nearest(elevation, list(ELEVATION_MAP.keys()))
|
| 51 |
+
distance_snapped = snap_to_nearest(distance, list(DISTANCE_MAP.keys()))
|
| 52 |
+
|
| 53 |
+
azimuth_name = AZIMUTH_MAP[azimuth_snapped]
|
| 54 |
+
elevation_name = ELEVATION_MAP[elevation_snapped]
|
| 55 |
+
distance_name = DISTANCE_MAP[distance_snapped]
|
| 56 |
+
|
| 57 |
+
return f"<sks> {azimuth_name} {elevation_name} {distance_name}"
|
| 58 |
+
|
| 59 |
+
from gradio.events import Dependency
|
| 60 |
+
|
| 61 |
+
# --- 3D Camera Control Component ---
|
| 62 |
+
class CameraControl3D(gr.HTML):
|
| 63 |
+
"""
|
| 64 |
+
A 3D camera control component using Three.js.
|
| 65 |
+
Outputs: { azimuth: number, elevation: number, distance: number }
|
| 66 |
+
Accepts imageUrl prop to display user's uploaded image on the plane.
|
| 67 |
+
"""
|
| 68 |
+
def __init__(self, value=None, imageUrl=None, **kwargs):
|
| 69 |
+
if value is None:
|
| 70 |
+
value = {"azimuth": 0, "elevation": 0, "distance": 1.0}
|
| 71 |
+
|
| 72 |
+
html_template = """
|
| 73 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
|
| 74 |
+
<div id="camera-control-wrapper" style="width: 100%; height: 450px; position: relative; background: #1a1a1a; border-radius: 12px; overflow: hidden;">
|
| 75 |
+
<div id="prompt-overlay" style="position: absolute; bottom: 10px; left: 50%; transform: translateX(-50%); background: rgba(0,0,0,0.8); padding: 8px 16px; border-radius: 8px; font-family: monospace; font-size: 12px; color: #00ff88; white-space: nowrap; z-index: 10;"></div>
|
| 76 |
+
</div>
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
js_on_load = """
|
| 80 |
+
(() => {
|
| 81 |
+
const wrapper = element.querySelector('#camera-control-wrapper');
|
| 82 |
+
const promptOverlay = element.querySelector('#prompt-overlay');
|
| 83 |
+
|
| 84 |
+
// Wait for THREE to load
|
| 85 |
+
const initScene = () => {
|
| 86 |
+
if (typeof THREE === 'undefined') {
|
| 87 |
+
setTimeout(initScene, 100);
|
| 88 |
+
return;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// Scene setup
|
| 92 |
+
const scene = new THREE.Scene();
|
| 93 |
+
scene.background = new THREE.Color(0x1a1a1a);
|
| 94 |
+
|
| 95 |
+
const camera = new THREE.PerspectiveCamera(50, wrapper.clientWidth / wrapper.clientHeight, 0.1, 1000);
|
| 96 |
+
camera.position.set(4.5, 3, 4.5);
|
| 97 |
+
camera.lookAt(0, 0.75, 0);
|
| 98 |
+
|
| 99 |
+
const renderer = new THREE.WebGLRenderer({ antialias: true });
|
| 100 |
+
renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
|
| 101 |
+
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
|
| 102 |
+
wrapper.insertBefore(renderer.domElement, promptOverlay);
|
| 103 |
+
|
| 104 |
+
// Lighting
|
| 105 |
+
scene.add(new THREE.AmbientLight(0xffffff, 0.6));
|
| 106 |
+
const dirLight = new THREE.DirectionalLight(0xffffff, 0.6);
|
| 107 |
+
dirLight.position.set(5, 10, 5);
|
| 108 |
+
scene.add(dirLight);
|
| 109 |
+
|
| 110 |
+
// Grid
|
| 111 |
+
scene.add(new THREE.GridHelper(8, 16, 0x333333, 0x222222));
|
| 112 |
+
|
| 113 |
+
// Constants - reduced distances for tighter framing
|
| 114 |
+
const CENTER = new THREE.Vector3(0, 0.75, 0);
|
| 115 |
+
const BASE_DISTANCE = 1.6;
|
| 116 |
+
const AZIMUTH_RADIUS = 2.4;
|
| 117 |
+
const ELEVATION_RADIUS = 1.8;
|
| 118 |
+
|
| 119 |
+
// State
|
| 120 |
+
let azimuthAngle = props.value?.azimuth || 0;
|
| 121 |
+
let elevationAngle = props.value?.elevation || 0;
|
| 122 |
+
let distanceFactor = props.value?.distance || 1.0;
|
| 123 |
+
|
| 124 |
+
// Mappings - reduced wide shot multiplier
|
| 125 |
+
const azimuthSteps = [0, 45, 90, 135, 180, 225, 270, 315];
|
| 126 |
+
const elevationSteps = [-30, 0, 30, 60];
|
| 127 |
+
const distanceSteps = [0.6, 1.0, 1.4];
|
| 128 |
+
|
| 129 |
+
const azimuthNames = {
|
| 130 |
+
0: 'front view', 45: 'front-right quarter view', 90: 'right side view',
|
| 131 |
+
135: 'back-right quarter view', 180: 'back view', 225: 'back-left quarter view',
|
| 132 |
+
270: 'left side view', 315: 'front-left quarter view'
|
| 133 |
+
};
|
| 134 |
+
const elevationNames = { '-30': 'low-angle shot', '0': 'eye-level shot', '30': 'elevated shot', '60': 'high-angle shot' };
|
| 135 |
+
const distanceNames = { '0.6': 'close-up', '1': 'medium shot', '1.4': 'wide shot' };
|
| 136 |
+
|
| 137 |
+
function snapToNearest(value, steps) {
|
| 138 |
+
return steps.reduce((prev, curr) => Math.abs(curr - value) < Math.abs(prev - value) ? curr : prev);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// Create placeholder texture (smiley face)
|
| 142 |
+
function createPlaceholderTexture() {
|
| 143 |
+
const canvas = document.createElement('canvas');
|
| 144 |
+
canvas.width = 256;
|
| 145 |
+
canvas.height = 256;
|
| 146 |
+
const ctx = canvas.getContext('2d');
|
| 147 |
+
ctx.fillStyle = '#3a3a4a';
|
| 148 |
+
ctx.fillRect(0, 0, 256, 256);
|
| 149 |
+
ctx.fillStyle = '#ffcc99';
|
| 150 |
+
ctx.beginPath();
|
| 151 |
+
ctx.arc(128, 128, 80, 0, Math.PI * 2);
|
| 152 |
+
ctx.fill();
|
| 153 |
+
ctx.fillStyle = '#333';
|
| 154 |
+
ctx.beginPath();
|
| 155 |
+
ctx.arc(100, 110, 10, 0, Math.PI * 2);
|
| 156 |
+
ctx.arc(156, 110, 10, 0, Math.PI * 2);
|
| 157 |
+
ctx.fill();
|
| 158 |
+
ctx.strokeStyle = '#333';
|
| 159 |
+
ctx.lineWidth = 3;
|
| 160 |
+
ctx.beginPath();
|
| 161 |
+
ctx.arc(128, 130, 35, 0.2, Math.PI - 0.2);
|
| 162 |
+
ctx.stroke();
|
| 163 |
+
return new THREE.CanvasTexture(canvas);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// Target image plane
|
| 167 |
+
let currentTexture = createPlaceholderTexture();
|
| 168 |
+
const planeMaterial = new THREE.MeshBasicMaterial({ map: currentTexture, side: THREE.DoubleSide });
|
| 169 |
+
let targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
|
| 170 |
+
targetPlane.position.copy(CENTER);
|
| 171 |
+
scene.add(targetPlane);
|
| 172 |
+
|
| 173 |
+
// Function to update texture from image URL
|
| 174 |
+
function updateTextureFromUrl(url) {
|
| 175 |
+
if (!url) {
|
| 176 |
+
// Reset to placeholder
|
| 177 |
+
planeMaterial.map = createPlaceholderTexture();
|
| 178 |
+
planeMaterial.needsUpdate = true;
|
| 179 |
+
// Reset plane to square
|
| 180 |
+
scene.remove(targetPlane);
|
| 181 |
+
targetPlane = new THREE.Mesh(new THREE.PlaneGeometry(1.2, 1.2), planeMaterial);
|
| 182 |
+
targetPlane.position.copy(CENTER);
|
| 183 |
+
scene.add(targetPlane);
|
| 184 |
+
return;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
const loader = new THREE.TextureLoader();
|
| 188 |
+
loader.crossOrigin = 'anonymous';
|
| 189 |
+
loader.load(url, (texture) => {
|
| 190 |
+
texture.minFilter = THREE.LinearFilter;
|
| 191 |
+
texture.magFilter = THREE.LinearFilter;
|
| 192 |
+
planeMaterial.map = texture;
|
| 193 |
+
planeMaterial.needsUpdate = true;
|
| 194 |
+
|
| 195 |
+
// Adjust plane aspect ratio to match image
|
| 196 |
+
const img = texture.image;
|
| 197 |
+
if (img && img.width && img.height) {
|
| 198 |
+
const aspect = img.width / img.height;
|
| 199 |
+
const maxSize = 1.5;
|
| 200 |
+
let planeWidth, planeHeight;
|
| 201 |
+
if (aspect > 1) {
|
| 202 |
+
planeWidth = maxSize;
|
| 203 |
+
planeHeight = maxSize / aspect;
|
| 204 |
+
} else {
|
| 205 |
+
planeHeight = maxSize;
|
| 206 |
+
planeWidth = maxSize * aspect;
|
| 207 |
+
}
|
| 208 |
+
scene.remove(targetPlane);
|
| 209 |
+
targetPlane = new THREE.Mesh(
|
| 210 |
+
new THREE.PlaneGeometry(planeWidth, planeHeight),
|
| 211 |
+
planeMaterial
|
| 212 |
+
);
|
| 213 |
+
targetPlane.position.copy(CENTER);
|
| 214 |
+
scene.add(targetPlane);
|
| 215 |
+
}
|
| 216 |
+
}, undefined, (err) => {
|
| 217 |
+
console.error('Failed to load texture:', err);
|
| 218 |
+
});
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// Check for initial imageUrl
|
| 222 |
+
if (props.imageUrl) {
|
| 223 |
+
updateTextureFromUrl(props.imageUrl);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
// Camera model
|
| 227 |
+
const cameraGroup = new THREE.Group();
|
| 228 |
+
const bodyMat = new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 });
|
| 229 |
+
const body = new THREE.Mesh(new THREE.BoxGeometry(0.3, 0.22, 0.38), bodyMat);
|
| 230 |
+
cameraGroup.add(body);
|
| 231 |
+
const lens = new THREE.Mesh(
|
| 232 |
+
new THREE.CylinderGeometry(0.09, 0.11, 0.18, 16),
|
| 233 |
+
new THREE.MeshStandardMaterial({ color: 0x6699cc, metalness: 0.5, roughness: 0.3 })
|
| 234 |
+
);
|
| 235 |
+
lens.rotation.x = Math.PI / 2;
|
| 236 |
+
lens.position.z = 0.26;
|
| 237 |
+
cameraGroup.add(lens);
|
| 238 |
+
scene.add(cameraGroup);
|
| 239 |
+
|
| 240 |
+
// GREEN: Azimuth ring
|
| 241 |
+
const azimuthRing = new THREE.Mesh(
|
| 242 |
+
new THREE.TorusGeometry(AZIMUTH_RADIUS, 0.04, 16, 64),
|
| 243 |
+
new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.3 })
|
| 244 |
+
);
|
| 245 |
+
azimuthRing.rotation.x = Math.PI / 2;
|
| 246 |
+
azimuthRing.position.y = 0.05;
|
| 247 |
+
scene.add(azimuthRing);
|
| 248 |
+
|
| 249 |
+
const azimuthHandle = new THREE.Mesh(
|
| 250 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 251 |
+
new THREE.MeshStandardMaterial({ color: 0x00ff88, emissive: 0x00ff88, emissiveIntensity: 0.5 })
|
| 252 |
+
);
|
| 253 |
+
azimuthHandle.userData.type = 'azimuth';
|
| 254 |
+
scene.add(azimuthHandle);
|
| 255 |
+
|
| 256 |
+
// PINK: Elevation arc
|
| 257 |
+
const arcPoints = [];
|
| 258 |
+
for (let i = 0; i <= 32; i++) {
|
| 259 |
+
const angle = THREE.MathUtils.degToRad(-30 + (90 * i / 32));
|
| 260 |
+
arcPoints.push(new THREE.Vector3(-0.8, ELEVATION_RADIUS * Math.sin(angle) + CENTER.y, ELEVATION_RADIUS * Math.cos(angle)));
|
| 261 |
+
}
|
| 262 |
+
const arcCurve = new THREE.CatmullRomCurve3(arcPoints);
|
| 263 |
+
const elevationArc = new THREE.Mesh(
|
| 264 |
+
new THREE.TubeGeometry(arcCurve, 32, 0.04, 8, false),
|
| 265 |
+
new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.3 })
|
| 266 |
+
);
|
| 267 |
+
scene.add(elevationArc);
|
| 268 |
+
|
| 269 |
+
const elevationHandle = new THREE.Mesh(
|
| 270 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 271 |
+
new THREE.MeshStandardMaterial({ color: 0xff69b4, emissive: 0xff69b4, emissiveIntensity: 0.5 })
|
| 272 |
+
);
|
| 273 |
+
elevationHandle.userData.type = 'elevation';
|
| 274 |
+
scene.add(elevationHandle);
|
| 275 |
+
|
| 276 |
+
// ORANGE: Distance line & handle
|
| 277 |
+
const distanceLineGeo = new THREE.BufferGeometry();
|
| 278 |
+
const distanceLine = new THREE.Line(distanceLineGeo, new THREE.LineBasicMaterial({ color: 0xffa500 }));
|
| 279 |
+
scene.add(distanceLine);
|
| 280 |
+
|
| 281 |
+
const distanceHandle = new THREE.Mesh(
|
| 282 |
+
new THREE.SphereGeometry(0.18, 16, 16),
|
| 283 |
+
new THREE.MeshStandardMaterial({ color: 0xffa500, emissive: 0xffa500, emissiveIntensity: 0.5 })
|
| 284 |
+
);
|
| 285 |
+
distanceHandle.userData.type = 'distance';
|
| 286 |
+
scene.add(distanceHandle);
|
| 287 |
+
|
| 288 |
+
function updatePositions() {
|
| 289 |
+
const distance = BASE_DISTANCE * distanceFactor;
|
| 290 |
+
const azRad = THREE.MathUtils.degToRad(azimuthAngle);
|
| 291 |
+
const elRad = THREE.MathUtils.degToRad(elevationAngle);
|
| 292 |
+
|
| 293 |
+
const camX = distance * Math.sin(azRad) * Math.cos(elRad);
|
| 294 |
+
const camY = distance * Math.sin(elRad) + CENTER.y;
|
| 295 |
+
const camZ = distance * Math.cos(azRad) * Math.cos(elRad);
|
| 296 |
+
|
| 297 |
+
cameraGroup.position.set(camX, camY, camZ);
|
| 298 |
+
cameraGroup.lookAt(CENTER);
|
| 299 |
+
|
| 300 |
+
azimuthHandle.position.set(AZIMUTH_RADIUS * Math.sin(azRad), 0.05, AZIMUTH_RADIUS * Math.cos(azRad));
|
| 301 |
+
elevationHandle.position.set(-0.8, ELEVATION_RADIUS * Math.sin(elRad) + CENTER.y, ELEVATION_RADIUS * Math.cos(elRad));
|
| 302 |
+
|
| 303 |
+
const orangeDist = distance - 0.5;
|
| 304 |
+
distanceHandle.position.set(
|
| 305 |
+
orangeDist * Math.sin(azRad) * Math.cos(elRad),
|
| 306 |
+
orangeDist * Math.sin(elRad) + CENTER.y,
|
| 307 |
+
orangeDist * Math.cos(azRad) * Math.cos(elRad)
|
| 308 |
+
);
|
| 309 |
+
distanceLineGeo.setFromPoints([cameraGroup.position.clone(), CENTER.clone()]);
|
| 310 |
+
|
| 311 |
+
// Update prompt
|
| 312 |
+
const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
|
| 313 |
+
const elSnap = snapToNearest(elevationAngle, elevationSteps);
|
| 314 |
+
const distSnap = snapToNearest(distanceFactor, distanceSteps);
|
| 315 |
+
const distKey = distSnap === 1 ? '1' : distSnap.toFixed(1);
|
| 316 |
+
const prompt = '<sks> ' + azimuthNames[azSnap] + ' ' + elevationNames[String(elSnap)] + ' ' + distanceNames[distKey];
|
| 317 |
+
promptOverlay.textContent = prompt;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
function updatePropsAndTrigger() {
|
| 321 |
+
const azSnap = snapToNearest(azimuthAngle, azimuthSteps);
|
| 322 |
+
const elSnap = snapToNearest(elevationAngle, elevationSteps);
|
| 323 |
+
const distSnap = snapToNearest(distanceFactor, distanceSteps);
|
| 324 |
+
|
| 325 |
+
props.value = { azimuth: azSnap, elevation: elSnap, distance: distSnap };
|
| 326 |
+
trigger('change', props.value);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
// Raycasting
|
| 330 |
+
const raycaster = new THREE.Raycaster();
|
| 331 |
+
const mouse = new THREE.Vector2();
|
| 332 |
+
let isDragging = false;
|
| 333 |
+
let dragTarget = null;
|
| 334 |
+
let dragStartMouse = new THREE.Vector2();
|
| 335 |
+
let dragStartDistance = 1.0;
|
| 336 |
+
const intersection = new THREE.Vector3();
|
| 337 |
+
|
| 338 |
+
const canvas = renderer.domElement;
|
| 339 |
+
|
| 340 |
+
canvas.addEventListener('mousedown', (e) => {
|
| 341 |
+
const rect = canvas.getBoundingClientRect();
|
| 342 |
+
mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
|
| 343 |
+
mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
|
| 344 |
+
|
| 345 |
+
raycaster.setFromCamera(mouse, camera);
|
| 346 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 347 |
+
|
| 348 |
+
if (intersects.length > 0) {
|
| 349 |
+
isDragging = true;
|
| 350 |
+
dragTarget = intersects[0].object;
|
| 351 |
+
dragTarget.material.emissiveIntensity = 1.0;
|
| 352 |
+
dragTarget.scale.setScalar(1.3);
|
| 353 |
+
dragStartMouse.copy(mouse);
|
| 354 |
+
dragStartDistance = distanceFactor;
|
| 355 |
+
canvas.style.cursor = 'grabbing';
|
| 356 |
+
}
|
| 357 |
+
});
|
| 358 |
+
|
| 359 |
+
canvas.addEventListener('mousemove', (e) => {
|
| 360 |
+
const rect = canvas.getBoundingClientRect();
|
| 361 |
+
mouse.x = ((e.clientX - rect.left) / rect.width) * 2 - 1;
|
| 362 |
+
mouse.y = -((e.clientY - rect.top) / rect.height) * 2 + 1;
|
| 363 |
+
|
| 364 |
+
if (isDragging && dragTarget) {
|
| 365 |
+
raycaster.setFromCamera(mouse, camera);
|
| 366 |
+
|
| 367 |
+
if (dragTarget.userData.type === 'azimuth') {
|
| 368 |
+
const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
|
| 369 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 370 |
+
azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
|
| 371 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 372 |
+
}
|
| 373 |
+
} else if (dragTarget.userData.type === 'elevation') {
|
| 374 |
+
const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
|
| 375 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 376 |
+
const relY = intersection.y - CENTER.y;
|
| 377 |
+
const relZ = intersection.z;
|
| 378 |
+
elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
|
| 379 |
+
}
|
| 380 |
+
} else if (dragTarget.userData.type === 'distance') {
|
| 381 |
+
const deltaY = mouse.y - dragStartMouse.y;
|
| 382 |
+
distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
|
| 383 |
+
}
|
| 384 |
+
updatePositions();
|
| 385 |
+
} else {
|
| 386 |
+
raycaster.setFromCamera(mouse, camera);
|
| 387 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 388 |
+
[azimuthHandle, elevationHandle, distanceHandle].forEach(h => {
|
| 389 |
+
h.material.emissiveIntensity = 0.5;
|
| 390 |
+
h.scale.setScalar(1);
|
| 391 |
+
});
|
| 392 |
+
if (intersects.length > 0) {
|
| 393 |
+
intersects[0].object.material.emissiveIntensity = 0.8;
|
| 394 |
+
intersects[0].object.scale.setScalar(1.1);
|
| 395 |
+
canvas.style.cursor = 'grab';
|
| 396 |
+
} else {
|
| 397 |
+
canvas.style.cursor = 'default';
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
});
|
| 401 |
+
|
| 402 |
+
const onMouseUp = () => {
|
| 403 |
+
if (dragTarget) {
|
| 404 |
+
dragTarget.material.emissiveIntensity = 0.5;
|
| 405 |
+
dragTarget.scale.setScalar(1);
|
| 406 |
+
|
| 407 |
+
// Snap and animate
|
| 408 |
+
const targetAz = snapToNearest(azimuthAngle, azimuthSteps);
|
| 409 |
+
const targetEl = snapToNearest(elevationAngle, elevationSteps);
|
| 410 |
+
const targetDist = snapToNearest(distanceFactor, distanceSteps);
|
| 411 |
+
|
| 412 |
+
const startAz = azimuthAngle, startEl = elevationAngle, startDist = distanceFactor;
|
| 413 |
+
const startTime = Date.now();
|
| 414 |
+
|
| 415 |
+
function animateSnap() {
|
| 416 |
+
const t = Math.min((Date.now() - startTime) / 200, 1);
|
| 417 |
+
const ease = 1 - Math.pow(1 - t, 3);
|
| 418 |
+
|
| 419 |
+
let azDiff = targetAz - startAz;
|
| 420 |
+
if (azDiff > 180) azDiff -= 360;
|
| 421 |
+
if (azDiff < -180) azDiff += 360;
|
| 422 |
+
azimuthAngle = startAz + azDiff * ease;
|
| 423 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 424 |
+
if (azimuthAngle >= 360) azimuthAngle -= 360;
|
| 425 |
+
|
| 426 |
+
elevationAngle = startEl + (targetEl - startEl) * ease;
|
| 427 |
+
distanceFactor = startDist + (targetDist - startDist) * ease;
|
| 428 |
+
|
| 429 |
+
updatePositions();
|
| 430 |
+
if (t < 1) requestAnimationFrame(animateSnap);
|
| 431 |
+
else updatePropsAndTrigger();
|
| 432 |
+
}
|
| 433 |
+
animateSnap();
|
| 434 |
+
}
|
| 435 |
+
isDragging = false;
|
| 436 |
+
dragTarget = null;
|
| 437 |
+
canvas.style.cursor = 'default';
|
| 438 |
+
};
|
| 439 |
+
|
| 440 |
+
canvas.addEventListener('mouseup', onMouseUp);
|
| 441 |
+
canvas.addEventListener('mouseleave', onMouseUp);
|
| 442 |
+
|
| 443 |
+
// Touch support for mobile
|
| 444 |
+
canvas.addEventListener('touchstart', (e) => {
|
| 445 |
+
e.preventDefault();
|
| 446 |
+
const touch = e.touches[0];
|
| 447 |
+
const rect = canvas.getBoundingClientRect();
|
| 448 |
+
mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
|
| 449 |
+
mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
|
| 450 |
+
|
| 451 |
+
raycaster.setFromCamera(mouse, camera);
|
| 452 |
+
const intersects = raycaster.intersectObjects([azimuthHandle, elevationHandle, distanceHandle]);
|
| 453 |
+
|
| 454 |
+
if (intersects.length > 0) {
|
| 455 |
+
isDragging = true;
|
| 456 |
+
dragTarget = intersects[0].object;
|
| 457 |
+
dragTarget.material.emissiveIntensity = 1.0;
|
| 458 |
+
dragTarget.scale.setScalar(1.3);
|
| 459 |
+
dragStartMouse.copy(mouse);
|
| 460 |
+
dragStartDistance = distanceFactor;
|
| 461 |
+
}
|
| 462 |
+
}, { passive: false });
|
| 463 |
+
|
| 464 |
+
canvas.addEventListener('touchmove', (e) => {
|
| 465 |
+
e.preventDefault();
|
| 466 |
+
const touch = e.touches[0];
|
| 467 |
+
const rect = canvas.getBoundingClientRect();
|
| 468 |
+
mouse.x = ((touch.clientX - rect.left) / rect.width) * 2 - 1;
|
| 469 |
+
mouse.y = -((touch.clientY - rect.top) / rect.height) * 2 + 1;
|
| 470 |
+
|
| 471 |
+
if (isDragging && dragTarget) {
|
| 472 |
+
raycaster.setFromCamera(mouse, camera);
|
| 473 |
+
|
| 474 |
+
if (dragTarget.userData.type === 'azimuth') {
|
| 475 |
+
const plane = new THREE.Plane(new THREE.Vector3(0, 1, 0), -0.05);
|
| 476 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 477 |
+
azimuthAngle = THREE.MathUtils.radToDeg(Math.atan2(intersection.x, intersection.z));
|
| 478 |
+
if (azimuthAngle < 0) azimuthAngle += 360;
|
| 479 |
+
}
|
| 480 |
+
} else if (dragTarget.userData.type === 'elevation') {
|
| 481 |
+
const plane = new THREE.Plane(new THREE.Vector3(1, 0, 0), -0.8);
|
| 482 |
+
if (raycaster.ray.intersectPlane(plane, intersection)) {
|
| 483 |
+
const relY = intersection.y - CENTER.y;
|
| 484 |
+
const relZ = intersection.z;
|
| 485 |
+
elevationAngle = THREE.MathUtils.clamp(THREE.MathUtils.radToDeg(Math.atan2(relY, relZ)), -30, 60);
|
| 486 |
+
}
|
| 487 |
+
} else if (dragTarget.userData.type === 'distance') {
|
| 488 |
+
const deltaY = mouse.y - dragStartMouse.y;
|
| 489 |
+
distanceFactor = THREE.MathUtils.clamp(dragStartDistance - deltaY * 1.5, 0.6, 1.4);
|
| 490 |
+
}
|
| 491 |
+
updatePositions();
|
| 492 |
+
}
|
| 493 |
+
}, { passive: false });
|
| 494 |
+
|
| 495 |
+
canvas.addEventListener('touchend', (e) => {
|
| 496 |
+
e.preventDefault();
|
| 497 |
+
onMouseUp();
|
| 498 |
+
}, { passive: false });
|
| 499 |
+
|
| 500 |
+
canvas.addEventListener('touchcancel', (e) => {
|
| 501 |
+
e.preventDefault();
|
| 502 |
+
onMouseUp();
|
| 503 |
+
}, { passive: false });
|
| 504 |
+
|
| 505 |
+
// Initial update
|
| 506 |
+
updatePositions();
|
| 507 |
+
|
| 508 |
+
// Render loop
|
| 509 |
+
function render() {
|
| 510 |
+
requestAnimationFrame(render);
|
| 511 |
+
renderer.render(scene, camera);
|
| 512 |
+
}
|
| 513 |
+
render();
|
| 514 |
+
|
| 515 |
+
// Handle resize
|
| 516 |
+
new ResizeObserver(() => {
|
| 517 |
+
camera.aspect = wrapper.clientWidth / wrapper.clientHeight;
|
| 518 |
+
camera.updateProjectionMatrix();
|
| 519 |
+
renderer.setSize(wrapper.clientWidth, wrapper.clientHeight);
|
| 520 |
+
}).observe(wrapper);
|
| 521 |
+
|
| 522 |
+
// Store update functions for external calls
|
| 523 |
+
wrapper._updateFromProps = (newVal) => {
|
| 524 |
+
if (newVal && typeof newVal === 'object') {
|
| 525 |
+
azimuthAngle = newVal.azimuth ?? azimuthAngle;
|
| 526 |
+
elevationAngle = newVal.elevation ?? elevationAngle;
|
| 527 |
+
distanceFactor = newVal.distance ?? distanceFactor;
|
| 528 |
+
updatePositions();
|
| 529 |
+
}
|
| 530 |
+
};
|
| 531 |
+
|
| 532 |
+
wrapper._updateTexture = updateTextureFromUrl;
|
| 533 |
+
|
| 534 |
+
// Watch for prop changes (imageUrl and value)
|
| 535 |
+
let lastImageUrl = props.imageUrl;
|
| 536 |
+
let lastValue = JSON.stringify(props.value);
|
| 537 |
+
setInterval(() => {
|
| 538 |
+
// Check imageUrl changes
|
| 539 |
+
if (props.imageUrl !== lastImageUrl) {
|
| 540 |
+
lastImageUrl = props.imageUrl;
|
| 541 |
+
updateTextureFromUrl(props.imageUrl);
|
| 542 |
+
}
|
| 543 |
+
// Check value changes (from sliders)
|
| 544 |
+
const currentValue = JSON.stringify(props.value);
|
| 545 |
+
if (currentValue !== lastValue) {
|
| 546 |
+
lastValue = currentValue;
|
| 547 |
+
if (props.value && typeof props.value === 'object') {
|
| 548 |
+
azimuthAngle = props.value.azimuth ?? azimuthAngle;
|
| 549 |
+
elevationAngle = props.value.elevation ?? elevationAngle;
|
| 550 |
+
distanceFactor = props.value.distance ?? distanceFactor;
|
| 551 |
+
updatePositions();
|
| 552 |
+
}
|
| 553 |
+
}
|
| 554 |
+
}, 100);
|
| 555 |
+
};
|
| 556 |
+
|
| 557 |
+
initScene();
|
| 558 |
+
})();
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
super().__init__(
|
| 562 |
+
value=value,
|
| 563 |
+
html_template=html_template,
|
| 564 |
+
js_on_load=js_on_load,
|
| 565 |
+
imageUrl=imageUrl,
|
| 566 |
+
**kwargs
|
| 567 |
+
)
|
| 568 |
+
from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING
|
| 569 |
+
from gradio.blocks import Block
|
| 570 |
+
if TYPE_CHECKING:
|
| 571 |
+
from gradio.components import Timer
|
| 572 |
+
from gradio.components.base import Component
|
examples/1.jpg
ADDED
|
examples/10.jpeg
ADDED
|
examples/11.jpg
ADDED
|
examples/12.jpg
ADDED
|
examples/13.jpg
ADDED
|
Git LFS Details
|
examples/14.jpg
ADDED
|
examples/2.jpeg
ADDED
|
examples/4.jpg
ADDED
|
examples/5.jpg
ADDED
|
examples/6.jpg
ADDED
|
examples/7.jpg
ADDED
|
examples/8.jpg
ADDED
|
examples/9.jpg
ADDED
|
examples/ELS.jpg
ADDED
|
pre-requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip>=23.0.0
|
qwenimage/__init__.py
ADDED
|
File without changes
|
qwenimage/pipeline_qwenimage_edit_plus.py
ADDED
|
@@ -0,0 +1,900 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import math
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from PIL import Image, ImageOps
|
| 23 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
|
| 24 |
+
|
| 25 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 26 |
+
from diffusers.loaders import QwenImageLoraLoaderMixin
|
| 27 |
+
from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
| 28 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 29 |
+
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
| 30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 31 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 32 |
+
from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
EXAMPLE_DOC_STRING = """
|
| 44 |
+
Examples:
|
| 45 |
+
```py
|
| 46 |
+
>>> import torch
|
| 47 |
+
>>> from diffusers import QwenImageEditPlusPipeline
|
| 48 |
+
>>> from diffusers.utils import load_image
|
| 49 |
+
|
| 50 |
+
>>> pipe = QwenImageEditPlusPipeline.from_pretrained(
|
| 51 |
+
... "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
|
| 52 |
+
... ).to("cuda")
|
| 53 |
+
|
| 54 |
+
>>> image = load_image(
|
| 55 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
|
| 56 |
+
... ).convert("RGB")
|
| 57 |
+
|
| 58 |
+
>>> prompt = "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
|
| 59 |
+
|
| 60 |
+
>>> out = pipe(image=image, prompt=prompt, num_inference_steps=50).images[0]
|
| 61 |
+
>>> out.save("qwenimage_edit_plus.png")
|
| 62 |
+
```
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
CONDITION_IMAGE_SIZE = 384 * 384
|
| 66 |
+
VAE_IMAGE_SIZE = 1024 * 1024
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def pad_to_aspect(img: Image.Image, target_w: int, target_h: int) -> Image.Image:
|
| 70 |
+
"""Pad (letterbox) to target aspect ratio without warping."""
|
| 71 |
+
return ImageOps.pad(
|
| 72 |
+
img.convert("RGB"),
|
| 73 |
+
(int(target_w), int(target_h)),
|
| 74 |
+
method=Image.Resampling.LANCZOS,
|
| 75 |
+
color=(0, 0, 0),
|
| 76 |
+
centering=(0.5, 0.5),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def choose_condition_area(canvas_area: int, base_area: int = CONDITION_IMAGE_SIZE) -> int:
|
| 81 |
+
"""Choose a conditioning target area derived from canvas area with sensible bounds."""
|
| 82 |
+
scaled = int(canvas_area * (base_area / (1024 * 1024)))
|
| 83 |
+
return int(min(base_area, max(256 * 256, scaled)))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
| 87 |
+
def calculate_shift(
|
| 88 |
+
image_seq_len,
|
| 89 |
+
base_seq_len: int = 256,
|
| 90 |
+
max_seq_len: int = 4096,
|
| 91 |
+
base_shift: float = 0.5,
|
| 92 |
+
max_shift: float = 1.15,
|
| 93 |
+
):
|
| 94 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 95 |
+
b = base_shift - m * base_seq_len
|
| 96 |
+
mu = image_seq_len * m + b
|
| 97 |
+
return mu
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 101 |
+
def retrieve_timesteps(
|
| 102 |
+
scheduler,
|
| 103 |
+
num_inference_steps: Optional[int] = None,
|
| 104 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 105 |
+
timesteps: Optional[List[int]] = None,
|
| 106 |
+
sigmas: Optional[List[float]] = None,
|
| 107 |
+
**kwargs,
|
| 108 |
+
):
|
| 109 |
+
if timesteps is not None and sigmas is not None:
|
| 110 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one.")
|
| 111 |
+
|
| 112 |
+
if timesteps is not None:
|
| 113 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 114 |
+
if not accepts_timesteps:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timesteps."
|
| 117 |
+
)
|
| 118 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 119 |
+
timesteps = scheduler.timesteps
|
| 120 |
+
num_inference_steps = len(timesteps)
|
| 121 |
+
|
| 122 |
+
elif sigmas is not None:
|
| 123 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 124 |
+
if not accept_sigmas:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas."
|
| 127 |
+
)
|
| 128 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 129 |
+
timesteps = scheduler.timesteps
|
| 130 |
+
num_inference_steps = len(timesteps)
|
| 131 |
+
|
| 132 |
+
else:
|
| 133 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 134 |
+
timesteps = scheduler.timesteps
|
| 135 |
+
|
| 136 |
+
return timesteps, num_inference_steps
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 140 |
+
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
|
| 141 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 142 |
+
return encoder_output.latent_dist.sample(generator)
|
| 143 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 144 |
+
return encoder_output.latent_dist.mode()
|
| 145 |
+
if hasattr(encoder_output, "latents"):
|
| 146 |
+
return encoder_output.latents
|
| 147 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def calculate_dimensions(target_area: int, ratio: float, multiple: int = 32):
|
| 151 |
+
"""
|
| 152 |
+
Area-based sizing while snapping to a chosen lattice multiple.
|
| 153 |
+
Used for canvas sizing AND conditioning sizing (anti-drift).
|
| 154 |
+
"""
|
| 155 |
+
m = int(multiple) if multiple else 32
|
| 156 |
+
m = max(1, m)
|
| 157 |
+
|
| 158 |
+
width = math.sqrt(float(target_area) * float(ratio))
|
| 159 |
+
height = width / float(ratio)
|
| 160 |
+
|
| 161 |
+
width = round(width / m) * m
|
| 162 |
+
height = round(height / m) * m
|
| 163 |
+
return int(width), int(height)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Optional: decoder VAE (Wan2x)
|
| 167 |
+
_ALT_VAE_WAN2X = None
|
| 168 |
+
|
| 169 |
+
# Track desired tiling state for the optional decoder VAE, so it stays consistent across lazy loads.
|
| 170 |
+
_ALT_VAE_WAN2X_TILING_ENABLED = False
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _set_vae_tiling(model: Any, enabled: bool) -> bool:
|
| 174 |
+
"""
|
| 175 |
+
Best-effort tiling toggle for a VAE-like module.
|
| 176 |
+
Returns True if a tiling method existed and was called, False otherwise.
|
| 177 |
+
"""
|
| 178 |
+
if model is None:
|
| 179 |
+
return False
|
| 180 |
+
try:
|
| 181 |
+
if enabled:
|
| 182 |
+
if hasattr(model, "enable_tiling"):
|
| 183 |
+
model.enable_tiling()
|
| 184 |
+
return True
|
| 185 |
+
if hasattr(model, "enable_vae_tiling"):
|
| 186 |
+
model.enable_vae_tiling()
|
| 187 |
+
return True
|
| 188 |
+
else:
|
| 189 |
+
if hasattr(model, "disable_tiling"):
|
| 190 |
+
model.disable_tiling()
|
| 191 |
+
return True
|
| 192 |
+
if hasattr(model, "disable_vae_tiling"):
|
| 193 |
+
model.disable_vae_tiling()
|
| 194 |
+
return True
|
| 195 |
+
except Exception as e:
|
| 196 |
+
# Don't hard-fail inference if tiling toggle fails for an alt decoder.
|
| 197 |
+
logger.warning(f"VAE tiling toggle failed on {type(model)}: {e}")
|
| 198 |
+
return False
|
| 199 |
+
return False
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _get_wan2x_vae(device: torch.device, dtype: torch.dtype):
|
| 204 |
+
"""
|
| 205 |
+
Decoder-only finetune that outputs 2x resolution via pixel-shuffle.
|
| 206 |
+
Lazy-loaded so it doesn't impact startup unless used.
|
| 207 |
+
"""
|
| 208 |
+
global _ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED
|
| 209 |
+
if _ALT_VAE_WAN2X is None:
|
| 210 |
+
from diffusers import AutoencoderKLWan
|
| 211 |
+
|
| 212 |
+
_ALT_VAE_WAN2X = AutoencoderKLWan.from_pretrained(
|
| 213 |
+
"spacepxl/Wan2.1-VAE-upscale2x",
|
| 214 |
+
subfolder="diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1",
|
| 215 |
+
torch_dtype=dtype,
|
| 216 |
+
)
|
| 217 |
+
_ALT_VAE_WAN2X.eval()
|
| 218 |
+
|
| 219 |
+
# Apply last requested tiling immediately on first load (if supported).
|
| 220 |
+
_set_vae_tiling(_ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED)
|
| 221 |
+
|
| 222 |
+
_ALT_VAE_WAN2X = _ALT_VAE_WAN2X.to(device=device, dtype=dtype)
|
| 223 |
+
|
| 224 |
+
# Re-apply after moving to device, just in case.
|
| 225 |
+
_set_vae_tiling(_ALT_VAE_WAN2X, _ALT_VAE_WAN2X_TILING_ENABLED)
|
| 226 |
+
|
| 227 |
+
return _ALT_VAE_WAN2X
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
| 231 |
+
r"""
|
| 232 |
+
The Qwen-Image-Edit pipeline for image editing.
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 236 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 241 |
+
vae: AutoencoderKLQwenImage,
|
| 242 |
+
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
| 243 |
+
tokenizer: Qwen2Tokenizer,
|
| 244 |
+
processor: Qwen2VLProcessor,
|
| 245 |
+
transformer: QwenImageTransformer2DModel,
|
| 246 |
+
):
|
| 247 |
+
super().__init__()
|
| 248 |
+
self.register_modules(
|
| 249 |
+
vae=vae,
|
| 250 |
+
text_encoder=text_encoder,
|
| 251 |
+
tokenizer=tokenizer,
|
| 252 |
+
processor=processor,
|
| 253 |
+
transformer=transformer,
|
| 254 |
+
scheduler=scheduler,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
| 258 |
+
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
| 259 |
+
|
| 260 |
+
# QwenImage latents are turned into 2x2 patches and packed; multiply scale-factor by patch size
|
| 261 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 262 |
+
self.tokenizer_max_length = 1024
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Track tiling state (applies to both primary VAE and optional decoder VAE)
|
| 266 |
+
self._vae_tiling_enabled = False
|
| 267 |
+
self.prompt_template_encode = (
|
| 268 |
+
"<|im_start|>system\n"
|
| 269 |
+
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
|
| 270 |
+
"then explain how the user's text instruction should alter or modify the image.\n"
|
| 271 |
+
"Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
|
| 272 |
+
"<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
| 273 |
+
)
|
| 274 |
+
self.prompt_template_encode_start_idx = 64
|
| 275 |
+
self.default_sample_size = 128
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# ------------------------------------------------------------
|
| 279 |
+
# VAE tiling control (applies to both primary VAE and optional decoder VAE)
|
| 280 |
+
# ------------------------------------------------------------
|
| 281 |
+
# Expose a stable API so app.py can call pipe.enable_vae_tiling()/disable_vae_tiling()
|
| 282 |
+
# regardless of which decoder VAE is selected at runtime.
|
| 283 |
+
|
| 284 |
+
def set_vae_tiling(self, enabled: bool) -> None:
|
| 285 |
+
global _ALT_VAE_WAN2X_TILING_ENABLED, _ALT_VAE_WAN2X
|
| 286 |
+
|
| 287 |
+
enabled = bool(enabled)
|
| 288 |
+
self._vae_tiling_enabled = enabled
|
| 289 |
+
|
| 290 |
+
# 1) Primary VAE (Qwen)
|
| 291 |
+
_set_vae_tiling(getattr(self, "vae", None), enabled)
|
| 292 |
+
|
| 293 |
+
# 2) Optional decoder VAE (Wan2x): store desired global state; apply now if already loaded.
|
| 294 |
+
_ALT_VAE_WAN2X_TILING_ENABLED = enabled
|
| 295 |
+
if _ALT_VAE_WAN2X is not None:
|
| 296 |
+
_set_vae_tiling(_ALT_VAE_WAN2X, enabled)
|
| 297 |
+
|
| 298 |
+
def enable_vae_tiling(self) -> None:
|
| 299 |
+
self.set_vae_tiling(True)
|
| 300 |
+
|
| 301 |
+
def disable_vae_tiling(self) -> None:
|
| 302 |
+
self.set_vae_tiling(False)
|
| 303 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
| 304 |
+
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
| 305 |
+
bool_mask = mask.bool()
|
| 306 |
+
valid_lengths = bool_mask.sum(dim=1)
|
| 307 |
+
selected = hidden_states[bool_mask]
|
| 308 |
+
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
| 309 |
+
return split_result
|
| 310 |
+
|
| 311 |
+
def _get_qwen_prompt_embeds(
|
| 312 |
+
self,
|
| 313 |
+
prompt: Union[str, List[str]] = None,
|
| 314 |
+
image: Optional[torch.Tensor] = None,
|
| 315 |
+
device: Optional[torch.device] = None,
|
| 316 |
+
dtype: Optional[torch.dtype] = None,
|
| 317 |
+
):
|
| 318 |
+
device = device or self._execution_device
|
| 319 |
+
dtype = dtype or self.text_encoder.dtype
|
| 320 |
+
|
| 321 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 322 |
+
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
| 323 |
+
|
| 324 |
+
if isinstance(image, list):
|
| 325 |
+
base_img_prompt = ""
|
| 326 |
+
for i, _ in enumerate(image):
|
| 327 |
+
base_img_prompt += img_prompt_template.format(i + 1)
|
| 328 |
+
elif image is not None:
|
| 329 |
+
base_img_prompt = img_prompt_template.format(1)
|
| 330 |
+
else:
|
| 331 |
+
base_img_prompt = ""
|
| 332 |
+
|
| 333 |
+
template = self.prompt_template_encode
|
| 334 |
+
drop_idx = self.prompt_template_encode_start_idx
|
| 335 |
+
txt = [template.format(base_img_prompt + e) for e in prompt]
|
| 336 |
+
|
| 337 |
+
model_inputs = self.processor(text=txt, images=image, padding=True, return_tensors="pt").to(device)
|
| 338 |
+
|
| 339 |
+
outputs = self.text_encoder(
|
| 340 |
+
input_ids=model_inputs.input_ids,
|
| 341 |
+
attention_mask=model_inputs.attention_mask,
|
| 342 |
+
pixel_values=model_inputs.pixel_values,
|
| 343 |
+
image_grid_thw=model_inputs.image_grid_thw,
|
| 344 |
+
output_hidden_states=True,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
hidden_states = outputs.hidden_states[-1]
|
| 348 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
| 349 |
+
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 350 |
+
|
| 351 |
+
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 352 |
+
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
| 353 |
+
|
| 354 |
+
prompt_embeds = torch.stack(
|
| 355 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
| 356 |
+
)
|
| 357 |
+
encoder_attention_mask = torch.stack(
|
| 358 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 362 |
+
return prompt_embeds, encoder_attention_mask
|
| 363 |
+
|
| 364 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
|
| 365 |
+
def encode_prompt(
|
| 366 |
+
self,
|
| 367 |
+
prompt: Union[str, List[str]],
|
| 368 |
+
image: Optional[torch.Tensor] = None,
|
| 369 |
+
device: Optional[torch.device] = None,
|
| 370 |
+
num_images_per_prompt: int = 1,
|
| 371 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 372 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 373 |
+
max_sequence_length: int = 1024,
|
| 374 |
+
):
|
| 375 |
+
device = device or self._execution_device
|
| 376 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 377 |
+
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
| 378 |
+
|
| 379 |
+
if prompt_embeds is None:
|
| 380 |
+
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
|
| 381 |
+
|
| 382 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 383 |
+
|
| 384 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 385 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 386 |
+
|
| 387 |
+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
| 388 |
+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
| 389 |
+
|
| 390 |
+
return prompt_embeds, prompt_embeds_mask
|
| 391 |
+
|
| 392 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
| 393 |
+
def check_inputs(
|
| 394 |
+
self,
|
| 395 |
+
prompt,
|
| 396 |
+
height,
|
| 397 |
+
width,
|
| 398 |
+
negative_prompt=None,
|
| 399 |
+
prompt_embeds=None,
|
| 400 |
+
negative_prompt_embeds=None,
|
| 401 |
+
prompt_embeds_mask=None,
|
| 402 |
+
negative_prompt_embeds_mask=None,
|
| 403 |
+
callback_on_step_end_tensor_inputs=None,
|
| 404 |
+
max_sequence_length=None,
|
| 405 |
+
):
|
| 406 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 407 |
+
logger.warning(
|
| 408 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. "
|
| 409 |
+
"Dimensions will be resized accordingly."
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 413 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 414 |
+
):
|
| 415 |
+
raise ValueError(
|
| 416 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
| 417 |
+
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if prompt is not None and prompt_embeds is not None:
|
| 421 |
+
raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
|
| 422 |
+
if prompt is None and prompt_embeds is None:
|
| 423 |
+
raise ValueError("Provide either `prompt` or `prompt_embeds`.")
|
| 424 |
+
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 425 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 426 |
+
|
| 427 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 428 |
+
raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.")
|
| 429 |
+
|
| 430 |
+
if prompt_embeds is not None and prompt_embeds_mask is None:
|
| 431 |
+
raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` must also be passed.")
|
| 432 |
+
|
| 433 |
+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
| 434 |
+
raise ValueError("If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` must also be passed.")
|
| 435 |
+
|
| 436 |
+
if max_sequence_length is not None and max_sequence_length > 1024:
|
| 437 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
| 438 |
+
|
| 439 |
+
@staticmethod
|
| 440 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 441 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 442 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 443 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 444 |
+
return latents
|
| 445 |
+
|
| 446 |
+
@staticmethod
|
| 447 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 448 |
+
batch_size, _, channels = latents.shape
|
| 449 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 450 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 451 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 452 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 453 |
+
latents = latents.reshape(batch_size, channels // 4, 1, height, width)
|
| 454 |
+
return latents
|
| 455 |
+
|
| 456 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 457 |
+
if isinstance(generator, list):
|
| 458 |
+
image_latents = [
|
| 459 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
| 460 |
+
for i in range(image.shape[0])
|
| 461 |
+
]
|
| 462 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 463 |
+
else:
|
| 464 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 465 |
+
|
| 466 |
+
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(
|
| 467 |
+
image_latents.device, image_latents.dtype
|
| 468 |
+
)
|
| 469 |
+
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(
|
| 470 |
+
image_latents.device, image_latents.dtype
|
| 471 |
+
)
|
| 472 |
+
image_latents = (image_latents - latents_mean) / latents_std
|
| 473 |
+
return image_latents
|
| 474 |
+
|
| 475 |
+
def prepare_latents(
|
| 476 |
+
self,
|
| 477 |
+
images,
|
| 478 |
+
batch_size,
|
| 479 |
+
num_channels_latents,
|
| 480 |
+
height,
|
| 481 |
+
width,
|
| 482 |
+
dtype,
|
| 483 |
+
device,
|
| 484 |
+
generator,
|
| 485 |
+
latents=None,
|
| 486 |
+
):
|
| 487 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 488 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 489 |
+
shape = (batch_size, 1, num_channels_latents, height, width)
|
| 490 |
+
|
| 491 |
+
image_latents = None
|
| 492 |
+
if images is not None:
|
| 493 |
+
if not isinstance(images, list):
|
| 494 |
+
images = [images]
|
| 495 |
+
all_image_latents = []
|
| 496 |
+
|
| 497 |
+
for image in images:
|
| 498 |
+
image = image.to(device=device, dtype=dtype)
|
| 499 |
+
if image.shape[1] != self.latent_channels:
|
| 500 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 501 |
+
else:
|
| 502 |
+
image_latents = image
|
| 503 |
+
|
| 504 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 505 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 506 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 507 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 508 |
+
raise ValueError(
|
| 509 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
image_latent_height, image_latent_width = image_latents.shape[3:]
|
| 513 |
+
image_latents = self._pack_latents(
|
| 514 |
+
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
| 515 |
+
)
|
| 516 |
+
all_image_latents.append(image_latents)
|
| 517 |
+
|
| 518 |
+
image_latents = torch.cat(all_image_latents, dim=1)
|
| 519 |
+
|
| 520 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
f"You passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}."
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if latents is None:
|
| 526 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 527 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 528 |
+
else:
|
| 529 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 530 |
+
|
| 531 |
+
return latents, image_latents
|
| 532 |
+
|
| 533 |
+
@property
|
| 534 |
+
def guidance_scale(self):
|
| 535 |
+
return self._guidance_scale
|
| 536 |
+
|
| 537 |
+
@property
|
| 538 |
+
def attention_kwargs(self):
|
| 539 |
+
return self._attention_kwargs
|
| 540 |
+
|
| 541 |
+
@property
|
| 542 |
+
def num_timesteps(self):
|
| 543 |
+
return self._num_timesteps
|
| 544 |
+
|
| 545 |
+
@property
|
| 546 |
+
def current_timestep(self):
|
| 547 |
+
return self._current_timestep
|
| 548 |
+
|
| 549 |
+
@property
|
| 550 |
+
def interrupt(self):
|
| 551 |
+
return self._interrupt
|
| 552 |
+
|
| 553 |
+
@torch.no_grad()
|
| 554 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 555 |
+
def __call__(
|
| 556 |
+
self,
|
| 557 |
+
image: Optional[PipelineImageInput] = None,
|
| 558 |
+
prompt: Union[str, List[str]] = None,
|
| 559 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 560 |
+
true_cfg_scale: float = 4.0,
|
| 561 |
+
height: Optional[int] = None,
|
| 562 |
+
width: Optional[int] = None,
|
| 563 |
+
condition_area: Optional[int] = None,
|
| 564 |
+
vae_image_indices: Optional[List[int]] = None,
|
| 565 |
+
pad_to_canvas: bool = True,
|
| 566 |
+
# NEW: lattice + VAE ref override
|
| 567 |
+
resolution_multiple: Optional[int] = None,
|
| 568 |
+
vae_ref_area: Optional[int] = None,
|
| 569 |
+
vae_ref_start_index: int = 2,
|
| 570 |
+
# Optional: decoder swap
|
| 571 |
+
decoder_vae: str = "qwen", # "qwen" | "wan2x"
|
| 572 |
+
keep_decoder_2x: bool = False,
|
| 573 |
+
# standard args
|
| 574 |
+
num_inference_steps: int = 50,
|
| 575 |
+
sigmas: Optional[List[float]] = None,
|
| 576 |
+
guidance_scale: Optional[float] = None,
|
| 577 |
+
num_images_per_prompt: int = 1,
|
| 578 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 579 |
+
latents: Optional[torch.Tensor] = None,
|
| 580 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 581 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 582 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 583 |
+
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 584 |
+
output_type: Optional[str] = "pil",
|
| 585 |
+
return_dict: bool = True,
|
| 586 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 587 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 588 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 589 |
+
max_sequence_length: int = 512,
|
| 590 |
+
):
|
| 591 |
+
"""Run Qwen-Image-Edit inference.
|
| 592 |
+
|
| 593 |
+
Examples:
|
| 594 |
+
"""
|
| 595 |
+
# ---- determine input size ----
|
| 596 |
+
if isinstance(image, list):
|
| 597 |
+
image_size = image[0].size
|
| 598 |
+
else:
|
| 599 |
+
image_size = image.size
|
| 600 |
+
|
| 601 |
+
# Lattice multiple used throughout (canvas sizing + condition sizing)
|
| 602 |
+
multiple_of = int(resolution_multiple) if resolution_multiple is not None else (self.vae_scale_factor * 2)
|
| 603 |
+
multiple_of = max(1, multiple_of)
|
| 604 |
+
|
| 605 |
+
calculated_width, calculated_height = calculate_dimensions(
|
| 606 |
+
1024 * 1024, float(image_size[0]) / float(image_size[1]), multiple=multiple_of
|
| 607 |
+
)
|
| 608 |
+
height = height or calculated_height
|
| 609 |
+
width = width or calculated_width
|
| 610 |
+
|
| 611 |
+
width = (int(width) // multiple_of) * multiple_of
|
| 612 |
+
height = (int(height) // multiple_of) * multiple_of
|
| 613 |
+
|
| 614 |
+
# ---- validate ----
|
| 615 |
+
self.check_inputs(
|
| 616 |
+
prompt,
|
| 617 |
+
height,
|
| 618 |
+
width,
|
| 619 |
+
negative_prompt=negative_prompt,
|
| 620 |
+
prompt_embeds=prompt_embeds,
|
| 621 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 622 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 623 |
+
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 624 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 625 |
+
max_sequence_length=max_sequence_length,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
self._guidance_scale = guidance_scale
|
| 629 |
+
self._attention_kwargs = attention_kwargs
|
| 630 |
+
self._current_timestep = None
|
| 631 |
+
self._interrupt = False
|
| 632 |
+
|
| 633 |
+
# ---- call params ----
|
| 634 |
+
if prompt is not None and isinstance(prompt, str):
|
| 635 |
+
batch_size = 1
|
| 636 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 637 |
+
batch_size = len(prompt)
|
| 638 |
+
else:
|
| 639 |
+
batch_size = prompt_embeds.shape[0]
|
| 640 |
+
|
| 641 |
+
device = self._execution_device
|
| 642 |
+
|
| 643 |
+
# ---- preprocess ----
|
| 644 |
+
condition_images = None
|
| 645 |
+
vae_images = None
|
| 646 |
+
vae_image_sizes: List[tuple[int, int]] = []
|
| 647 |
+
|
| 648 |
+
# support pre-latent tensors (rare, but keep compatibility)
|
| 649 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 650 |
+
if not isinstance(image, list):
|
| 651 |
+
image = [image]
|
| 652 |
+
|
| 653 |
+
canvas_area = int(width) * int(height)
|
| 654 |
+
cond_area = int(condition_area) if condition_area is not None else choose_condition_area(canvas_area)
|
| 655 |
+
|
| 656 |
+
cond_w, cond_h = calculate_dimensions(cond_area, float(width) / float(height), multiple=multiple_of)
|
| 657 |
+
|
| 658 |
+
# Optional VAE ref override sizing (applied only to indices >= vae_ref_start_index)
|
| 659 |
+
ref_w = ref_h = None
|
| 660 |
+
if vae_ref_area is not None:
|
| 661 |
+
try:
|
| 662 |
+
ref_w, ref_h = calculate_dimensions(
|
| 663 |
+
int(vae_ref_area),
|
| 664 |
+
float(width) / float(height),
|
| 665 |
+
multiple=multiple_of,
|
| 666 |
+
)
|
| 667 |
+
except Exception:
|
| 668 |
+
ref_w = ref_h = None
|
| 669 |
+
|
| 670 |
+
condition_images = []
|
| 671 |
+
vae_images = []
|
| 672 |
+
|
| 673 |
+
if vae_image_indices is None:
|
| 674 |
+
vae_image_indices = list(range(len(image)))
|
| 675 |
+
vae_set = set(int(i) for i in vae_image_indices)
|
| 676 |
+
|
| 677 |
+
for idx, img in enumerate(image):
|
| 678 |
+
pil = img.convert("RGB") if isinstance(img, Image.Image) else img
|
| 679 |
+
|
| 680 |
+
if pad_to_canvas and isinstance(pil, Image.Image):
|
| 681 |
+
pil = pad_to_aspect(pil, int(width), int(height))
|
| 682 |
+
|
| 683 |
+
# conditioning stream (always)
|
| 684 |
+
condition_images.append(self.image_processor.resize(pil, cond_h, cond_w))
|
| 685 |
+
|
| 686 |
+
# VAE stream (selective)
|
| 687 |
+
if idx in vae_set:
|
| 688 |
+
if (ref_w is not None) and (ref_h is not None) and (int(idx) >= int(vae_ref_start_index)):
|
| 689 |
+
vw, vh = int(ref_w), int(ref_h)
|
| 690 |
+
else:
|
| 691 |
+
vw, vh = int(width), int(height)
|
| 692 |
+
|
| 693 |
+
vae_image_sizes.append((vw, vh))
|
| 694 |
+
vae_images.append(self.image_processor.preprocess(pil, int(vh), int(vw)).unsqueeze(2))
|
| 695 |
+
|
| 696 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 697 |
+
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
| 698 |
+
)
|
| 699 |
+
if true_cfg_scale > 1 and not has_neg_prompt:
|
| 700 |
+
logger.warning(
|
| 701 |
+
f"true_cfg_scale={true_cfg_scale} but CFG disabled because no negative prompt was provided."
|
| 702 |
+
)
|
| 703 |
+
if true_cfg_scale <= 1 and has_neg_prompt:
|
| 704 |
+
logger.warning("negative_prompt provided but CFG disabled because true_cfg_scale <= 1")
|
| 705 |
+
|
| 706 |
+
do_true_cfg = (true_cfg_scale > 1) and has_neg_prompt
|
| 707 |
+
|
| 708 |
+
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
| 709 |
+
image=condition_images,
|
| 710 |
+
prompt=prompt,
|
| 711 |
+
prompt_embeds=prompt_embeds,
|
| 712 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 713 |
+
device=device,
|
| 714 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 715 |
+
max_sequence_length=max_sequence_length,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
if do_true_cfg:
|
| 719 |
+
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
| 720 |
+
image=condition_images,
|
| 721 |
+
prompt=negative_prompt,
|
| 722 |
+
prompt_embeds=negative_prompt_embeds,
|
| 723 |
+
prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 724 |
+
device=device,
|
| 725 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 726 |
+
max_sequence_length=max_sequence_length,
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
# ---- prepare latents ----
|
| 730 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 731 |
+
latents, image_latents = self.prepare_latents(
|
| 732 |
+
vae_images,
|
| 733 |
+
batch_size * num_images_per_prompt,
|
| 734 |
+
num_channels_latents,
|
| 735 |
+
height,
|
| 736 |
+
width,
|
| 737 |
+
prompt_embeds.dtype,
|
| 738 |
+
device,
|
| 739 |
+
generator,
|
| 740 |
+
latents,
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
img_shapes = [
|
| 744 |
+
[
|
| 745 |
+
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
| 746 |
+
*[
|
| 747 |
+
(1, vae_h // self.vae_scale_factor // 2, vae_w // self.vae_scale_factor // 2)
|
| 748 |
+
for (vae_w, vae_h) in vae_image_sizes
|
| 749 |
+
],
|
| 750 |
+
]
|
| 751 |
+
] * batch_size
|
| 752 |
+
|
| 753 |
+
else:
|
| 754 |
+
raise ValueError(
|
| 755 |
+
"This Space pipeline expects `image` as PIL/np inputs (not pre-latents) in this setup."
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# ---- timesteps ----
|
| 759 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 760 |
+
|
| 761 |
+
image_seq_len = latents.shape[1]
|
| 762 |
+
mu = calculate_shift(
|
| 763 |
+
image_seq_len,
|
| 764 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 765 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 766 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 767 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 768 |
+
)
|
| 769 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 770 |
+
self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 774 |
+
self._num_timesteps = len(timesteps)
|
| 775 |
+
|
| 776 |
+
# guidance-distilled models need explicit guidance input
|
| 777 |
+
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
| 778 |
+
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
| 779 |
+
if self.transformer.config.guidance_embeds:
|
| 780 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0])
|
| 781 |
+
else:
|
| 782 |
+
if guidance_scale is not None:
|
| 783 |
+
logger.warning("guidance_scale passed but ignored since model is not guidance-distilled.")
|
| 784 |
+
guidance = None
|
| 785 |
+
|
| 786 |
+
if self.attention_kwargs is None:
|
| 787 |
+
self._attention_kwargs = {}
|
| 788 |
+
|
| 789 |
+
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
| 790 |
+
image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
| 791 |
+
|
| 792 |
+
do_true_cfg = (
|
| 793 |
+
(true_cfg_scale > 1)
|
| 794 |
+
and (negative_prompt_embeds is not None)
|
| 795 |
+
and (negative_prompt_embeds_mask is not None)
|
| 796 |
+
)
|
| 797 |
+
if do_true_cfg:
|
| 798 |
+
negative_txt_seq_lens = negative_prompt_embeds_mask.sum(dim=1).tolist()
|
| 799 |
+
uncond_image_rotary_emb = self.transformer.pos_embed(img_shapes, negative_txt_seq_lens, device=latents.device)
|
| 800 |
+
else:
|
| 801 |
+
uncond_image_rotary_emb = None
|
| 802 |
+
|
| 803 |
+
# ---- denoise ----
|
| 804 |
+
self.scheduler.set_begin_index(0)
|
| 805 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 806 |
+
for i, t in enumerate(timesteps):
|
| 807 |
+
if self.interrupt:
|
| 808 |
+
continue
|
| 809 |
+
self._current_timestep = t
|
| 810 |
+
|
| 811 |
+
latent_model_input = latents
|
| 812 |
+
if image_latents is not None:
|
| 813 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 814 |
+
|
| 815 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 816 |
+
|
| 817 |
+
with self.transformer.cache_context("cond"):
|
| 818 |
+
noise_pred = self.transformer(
|
| 819 |
+
hidden_states=latent_model_input,
|
| 820 |
+
timestep=timestep / 1000,
|
| 821 |
+
guidance=guidance,
|
| 822 |
+
encoder_hidden_states_mask=prompt_embeds_mask,
|
| 823 |
+
encoder_hidden_states=prompt_embeds,
|
| 824 |
+
image_rotary_emb=image_rotary_emb,
|
| 825 |
+
attention_kwargs=self.attention_kwargs,
|
| 826 |
+
return_dict=False,
|
| 827 |
+
)[0]
|
| 828 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 829 |
+
|
| 830 |
+
if do_true_cfg:
|
| 831 |
+
with self.transformer.cache_context("uncond"):
|
| 832 |
+
neg_noise_pred = self.transformer(
|
| 833 |
+
hidden_states=latent_model_input,
|
| 834 |
+
timestep=timestep / 1000,
|
| 835 |
+
guidance=guidance,
|
| 836 |
+
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
| 837 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 838 |
+
image_rotary_emb=uncond_image_rotary_emb,
|
| 839 |
+
attention_kwargs=self.attention_kwargs,
|
| 840 |
+
return_dict=False,
|
| 841 |
+
)[0]
|
| 842 |
+
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
| 843 |
+
|
| 844 |
+
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 845 |
+
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
| 846 |
+
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
| 847 |
+
noise_pred = comb_pred * (cond_norm / (noise_norm + 1e-8))
|
| 848 |
+
|
| 849 |
+
latents_dtype = latents.dtype
|
| 850 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 851 |
+
if latents.dtype != latents_dtype and torch.backends.mps.is_available():
|
| 852 |
+
latents = latents.to(latents_dtype)
|
| 853 |
+
|
| 854 |
+
if callback_on_step_end is not None:
|
| 855 |
+
callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
|
| 856 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 857 |
+
latents = callback_outputs.pop("latents", latents)
|
| 858 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 859 |
+
|
| 860 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 861 |
+
progress_bar.update()
|
| 862 |
+
|
| 863 |
+
if XLA_AVAILABLE:
|
| 864 |
+
xm.mark_step()
|
| 865 |
+
|
| 866 |
+
self._current_timestep = None
|
| 867 |
+
|
| 868 |
+
# ---- decode ----
|
| 869 |
+
if output_type == "latent":
|
| 870 |
+
image_out = latents
|
| 871 |
+
else:
|
| 872 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 873 |
+
latents = latents.to(self.vae.dtype)
|
| 874 |
+
|
| 875 |
+
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 876 |
+
latents.device, latents.dtype
|
| 877 |
+
)
|
| 878 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 879 |
+
latents.device, latents.dtype
|
| 880 |
+
)
|
| 881 |
+
latents = latents / latents_std + latents_mean
|
| 882 |
+
|
| 883 |
+
if decoder_vae == "wan2x":
|
| 884 |
+
alt_vae = _get_wan2x_vae(latents.device, self.vae.dtype)
|
| 885 |
+
decoder_out = alt_vae.decode(latents, return_dict=False)[0] # [B, 12, F, H, W]
|
| 886 |
+
img_2x = F.pixel_shuffle(decoder_out[:, :, 0], upscale_factor=2) # [B, 3, 2H, 2W]
|
| 887 |
+
if keep_decoder_2x:
|
| 888 |
+
decoded = img_2x
|
| 889 |
+
else:
|
| 890 |
+
decoded = F.interpolate(img_2x, size=(int(height), int(width)), mode="area")
|
| 891 |
+
else:
|
| 892 |
+
decoded = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
| 893 |
+
|
| 894 |
+
image_out = self.image_processor.postprocess(decoded, output_type=output_type)
|
| 895 |
+
|
| 896 |
+
self.maybe_free_model_hooks()
|
| 897 |
+
|
| 898 |
+
if not return_dict:
|
| 899 |
+
return (image_out,)
|
| 900 |
+
return QwenImagePipelineOutput(images=image_out)
|
qwenimage/qwen_fa3_processor.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Paired with a good language model. Thanks!
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from kernels import get_kernel
|
| 11 |
+
_k = get_kernel("kernels-community/vllm-flash-attn3")
|
| 12 |
+
_flash_attn_func = _k.flash_attn_func
|
| 13 |
+
except Exception as e:
|
| 14 |
+
_flash_attn_func = None
|
| 15 |
+
_kernels_err = e
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _ensure_fa3_available():
|
| 19 |
+
if _flash_attn_func is None:
|
| 20 |
+
raise ImportError(
|
| 21 |
+
"FlashAttention-3 via Hugging Face `kernels` is required. "
|
| 22 |
+
"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
|
| 23 |
+
f"{_kernels_err}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
| 27 |
+
def flash_attn_func(
|
| 28 |
+
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
|
| 29 |
+
) -> torch.Tensor:
|
| 30 |
+
outputs, lse = _flash_attn_func(q, k, v, causal=causal)
|
| 31 |
+
return outputs
|
| 32 |
+
|
| 33 |
+
@flash_attn_func.register_fake
|
| 34 |
+
def _(q, k, v, **kwargs):
|
| 35 |
+
# two outputs:
|
| 36 |
+
# 1. output: (batch, seq_len, num_heads, head_dim)
|
| 37 |
+
# 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
|
| 38 |
+
meta_q = torch.empty_like(q).contiguous()
|
| 39 |
+
return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class QwenDoubleStreamAttnProcessorFA3:
|
| 43 |
+
"""
|
| 44 |
+
FA3-based attention processor for Qwen double-stream architecture.
|
| 45 |
+
Computes joint attention over concatenated [text, image] streams using vLLM FlashAttention-3
|
| 46 |
+
accessed via Hugging Face `kernels`.
|
| 47 |
+
|
| 48 |
+
Notes / limitations:
|
| 49 |
+
- General attention masks are not supported here (FA3 path). `is_causal=False` and no arbitrary mask.
|
| 50 |
+
- Optional windowed attention / sink tokens / softcap can be plumbed through if you use those features.
|
| 51 |
+
- Expects an available `apply_rotary_emb_qwen` in scope (same as your non-FA3 processor).
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
_attention_backend = "fa3" # for parity with your other processors, not used internally
|
| 55 |
+
|
| 56 |
+
def __init__(self):
|
| 57 |
+
_ensure_fa3_available()
|
| 58 |
+
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def __call__(
|
| 61 |
+
self,
|
| 62 |
+
attn, # Attention module with to_q/to_k/to_v/add_*_proj, norms, to_out, to_add_out, and .heads
|
| 63 |
+
hidden_states: torch.FloatTensor, # (B, S_img, D_model) image stream
|
| 64 |
+
encoder_hidden_states: torch.FloatTensor = None, # (B, S_txt, D_model) text stream
|
| 65 |
+
encoder_hidden_states_mask: torch.FloatTensor = None, # unused in FA3 path
|
| 66 |
+
attention_mask: Optional[torch.FloatTensor] = None, # unused in FA3 path
|
| 67 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # (img_freqs, txt_freqs)
|
| 68 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 69 |
+
if encoder_hidden_states is None:
|
| 70 |
+
raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
|
| 71 |
+
if attention_mask is not None:
|
| 72 |
+
# FA3 kernel path here does not consume arbitrary masks; fail fast to avoid silent correctness issues.
|
| 73 |
+
raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")
|
| 74 |
+
|
| 75 |
+
_ensure_fa3_available()
|
| 76 |
+
|
| 77 |
+
B, S_img, _ = hidden_states.shape
|
| 78 |
+
S_txt = encoder_hidden_states.shape[1]
|
| 79 |
+
|
| 80 |
+
# ---- QKV projections (image/sample stream) ----
|
| 81 |
+
img_q = attn.to_q(hidden_states) # (B, S_img, D)
|
| 82 |
+
img_k = attn.to_k(hidden_states)
|
| 83 |
+
img_v = attn.to_v(hidden_states)
|
| 84 |
+
|
| 85 |
+
# ---- QKV projections (text/context stream) ----
|
| 86 |
+
txt_q = attn.add_q_proj(encoder_hidden_states) # (B, S_txt, D)
|
| 87 |
+
txt_k = attn.add_k_proj(encoder_hidden_states)
|
| 88 |
+
txt_v = attn.add_v_proj(encoder_hidden_states)
|
| 89 |
+
|
| 90 |
+
# ---- Reshape to (B, S, H, D_h) ----
|
| 91 |
+
H = attn.heads
|
| 92 |
+
img_q = img_q.unflatten(-1, (H, -1))
|
| 93 |
+
img_k = img_k.unflatten(-1, (H, -1))
|
| 94 |
+
img_v = img_v.unflatten(-1, (H, -1))
|
| 95 |
+
|
| 96 |
+
txt_q = txt_q.unflatten(-1, (H, -1))
|
| 97 |
+
txt_k = txt_k.unflatten(-1, (H, -1))
|
| 98 |
+
txt_v = txt_v.unflatten(-1, (H, -1))
|
| 99 |
+
|
| 100 |
+
# ---- Q/K normalization (per your module contract) ----
|
| 101 |
+
if getattr(attn, "norm_q", None) is not None:
|
| 102 |
+
img_q = attn.norm_q(img_q)
|
| 103 |
+
if getattr(attn, "norm_k", None) is not None:
|
| 104 |
+
img_k = attn.norm_k(img_k)
|
| 105 |
+
if getattr(attn, "norm_added_q", None) is not None:
|
| 106 |
+
txt_q = attn.norm_added_q(txt_q)
|
| 107 |
+
if getattr(attn, "norm_added_k", None) is not None:
|
| 108 |
+
txt_k = attn.norm_added_k(txt_k)
|
| 109 |
+
|
| 110 |
+
# ---- RoPE (Qwen variant) ----
|
| 111 |
+
if image_rotary_emb is not None:
|
| 112 |
+
img_freqs, txt_freqs = image_rotary_emb
|
| 113 |
+
# expects tensors shaped (B, S, H, D_h)
|
| 114 |
+
img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
|
| 115 |
+
img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
|
| 116 |
+
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
|
| 117 |
+
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)
|
| 118 |
+
|
| 119 |
+
# ---- Joint attention over [text, image] along sequence axis ----
|
| 120 |
+
# Shapes: (B, S_total, H, D_h)
|
| 121 |
+
q = torch.cat([txt_q, img_q], dim=1)
|
| 122 |
+
k = torch.cat([txt_k, img_k], dim=1)
|
| 123 |
+
v = torch.cat([txt_v, img_v], dim=1)
|
| 124 |
+
|
| 125 |
+
# FlashAttention-3 path expects (B, S, H, D_h) and returns (out, softmax_lse)
|
| 126 |
+
out = flash_attn_func(q, k, v, causal=False) # out: (B, S_total, H, D_h)
|
| 127 |
+
|
| 128 |
+
# ---- Back to (B, S, D_model) ----
|
| 129 |
+
out = out.flatten(2, 3).to(q.dtype)
|
| 130 |
+
|
| 131 |
+
# Split back to text / image segments
|
| 132 |
+
txt_attn_out = out[:, :S_txt, :]
|
| 133 |
+
img_attn_out = out[:, S_txt:, :]
|
| 134 |
+
|
| 135 |
+
# ---- Output projections ----
|
| 136 |
+
img_attn_out = attn.to_out[0](img_attn_out)
|
| 137 |
+
if len(attn.to_out) > 1:
|
| 138 |
+
img_attn_out = attn.to_out[1](img_attn_out) # dropout if present
|
| 139 |
+
|
| 140 |
+
txt_attn_out = attn.to_add_out(txt_attn_out)
|
| 141 |
+
|
| 142 |
+
return img_attn_out, txt_attn_out
|
qwenimage/transformer_qwenimage.py
ADDED
|
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
import math
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 25 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 26 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
from diffusers.models.attention import FeedForward, AttentionMixin
|
| 28 |
+
from diffusers.models.attention_dispatch import dispatch_attention_fn
|
| 29 |
+
from diffusers.models.attention_processor import Attention
|
| 30 |
+
from diffusers.models.cache_utils import CacheMixin
|
| 31 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 32 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 34 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_timestep_embedding(
|
| 41 |
+
timesteps: torch.Tensor,
|
| 42 |
+
embedding_dim: int,
|
| 43 |
+
flip_sin_to_cos: bool = False,
|
| 44 |
+
downscale_freq_shift: float = 1,
|
| 45 |
+
scale: float = 1,
|
| 46 |
+
max_period: int = 10000,
|
| 47 |
+
) -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 50 |
+
|
| 51 |
+
Args
|
| 52 |
+
timesteps (torch.Tensor):
|
| 53 |
+
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
| 54 |
+
embedding_dim (int):
|
| 55 |
+
the dimension of the output.
|
| 56 |
+
flip_sin_to_cos (bool):
|
| 57 |
+
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
| 58 |
+
downscale_freq_shift (float):
|
| 59 |
+
Controls the delta between frequencies between dimensions
|
| 60 |
+
scale (float):
|
| 61 |
+
Scaling factor applied to the embeddings.
|
| 62 |
+
max_period (int):
|
| 63 |
+
Controls the maximum frequency of the embeddings
|
| 64 |
+
Returns
|
| 65 |
+
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
| 66 |
+
"""
|
| 67 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 68 |
+
|
| 69 |
+
half_dim = embedding_dim // 2
|
| 70 |
+
exponent = -math.log(max_period) * torch.arange(
|
| 71 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 72 |
+
)
|
| 73 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 74 |
+
|
| 75 |
+
emb = torch.exp(exponent).to(timesteps.dtype)
|
| 76 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 77 |
+
|
| 78 |
+
# scale embeddings
|
| 79 |
+
emb = scale * emb
|
| 80 |
+
|
| 81 |
+
# concat sine and cosine embeddings
|
| 82 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 83 |
+
|
| 84 |
+
# flip sine and cosine embeddings
|
| 85 |
+
if flip_sin_to_cos:
|
| 86 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
| 87 |
+
|
| 88 |
+
# zero pad
|
| 89 |
+
if embedding_dim % 2 == 1:
|
| 90 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 91 |
+
return emb
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def apply_rotary_emb_qwen(
|
| 95 |
+
x: torch.Tensor,
|
| 96 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 97 |
+
use_real: bool = True,
|
| 98 |
+
use_real_unbind_dim: int = -1,
|
| 99 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 100 |
+
"""
|
| 101 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 102 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 103 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 104 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
x (`torch.Tensor`):
|
| 108 |
+
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
| 109 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 113 |
+
"""
|
| 114 |
+
if use_real:
|
| 115 |
+
cos, sin = freqs_cis # [S, D]
|
| 116 |
+
cos = cos[None, None]
|
| 117 |
+
sin = sin[None, None]
|
| 118 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 119 |
+
|
| 120 |
+
if use_real_unbind_dim == -1:
|
| 121 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 122 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
| 123 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 124 |
+
elif use_real_unbind_dim == -2:
|
| 125 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 126 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
| 127 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 128 |
+
else:
|
| 129 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 130 |
+
|
| 131 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 132 |
+
|
| 133 |
+
return out
|
| 134 |
+
else:
|
| 135 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 136 |
+
freqs_cis = freqs_cis.unsqueeze(1)
|
| 137 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 138 |
+
|
| 139 |
+
return x_out.type_as(x)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class QwenTimestepProjEmbeddings(nn.Module):
|
| 143 |
+
def __init__(self, embedding_dim):
|
| 144 |
+
super().__init__()
|
| 145 |
+
|
| 146 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
| 147 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 148 |
+
|
| 149 |
+
def forward(self, timestep, hidden_states):
|
| 150 |
+
timesteps_proj = self.time_proj(timestep)
|
| 151 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
| 152 |
+
|
| 153 |
+
conditioning = timesteps_emb
|
| 154 |
+
|
| 155 |
+
return conditioning
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class QwenEmbedRope(nn.Module):
|
| 159 |
+
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.theta = theta
|
| 162 |
+
self.axes_dim = axes_dim
|
| 163 |
+
pos_index = torch.arange(4096)
|
| 164 |
+
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
| 165 |
+
self.pos_freqs = torch.cat(
|
| 166 |
+
[
|
| 167 |
+
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
| 168 |
+
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
| 169 |
+
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
| 170 |
+
],
|
| 171 |
+
dim=1,
|
| 172 |
+
)
|
| 173 |
+
self.neg_freqs = torch.cat(
|
| 174 |
+
[
|
| 175 |
+
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
| 176 |
+
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
| 177 |
+
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
| 178 |
+
],
|
| 179 |
+
dim=1,
|
| 180 |
+
)
|
| 181 |
+
self.rope_cache = {}
|
| 182 |
+
|
| 183 |
+
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
|
| 184 |
+
self.scale_rope = scale_rope
|
| 185 |
+
|
| 186 |
+
def rope_params(self, index, dim, theta=10000):
|
| 187 |
+
"""
|
| 188 |
+
Args:
|
| 189 |
+
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
| 190 |
+
"""
|
| 191 |
+
assert dim % 2 == 0
|
| 192 |
+
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
| 193 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 194 |
+
return freqs
|
| 195 |
+
|
| 196 |
+
def forward(self, video_fhw, txt_seq_lens, device):
|
| 197 |
+
"""
|
| 198 |
+
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
| 199 |
+
txt_length: [bs] a list of 1 integers representing the length of the text
|
| 200 |
+
"""
|
| 201 |
+
if self.pos_freqs.device != device:
|
| 202 |
+
self.pos_freqs = self.pos_freqs.to(device)
|
| 203 |
+
self.neg_freqs = self.neg_freqs.to(device)
|
| 204 |
+
|
| 205 |
+
if isinstance(video_fhw, list):
|
| 206 |
+
video_fhw = video_fhw[0]
|
| 207 |
+
if not isinstance(video_fhw, list):
|
| 208 |
+
video_fhw = [video_fhw]
|
| 209 |
+
|
| 210 |
+
vid_freqs = []
|
| 211 |
+
max_vid_index = 0
|
| 212 |
+
for idx, fhw in enumerate(video_fhw):
|
| 213 |
+
frame, height, width = fhw
|
| 214 |
+
rope_key = f"{idx}_{height}_{width}"
|
| 215 |
+
|
| 216 |
+
if not torch.compiler.is_compiling():
|
| 217 |
+
if rope_key not in self.rope_cache:
|
| 218 |
+
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
|
| 219 |
+
video_freq = self.rope_cache[rope_key]
|
| 220 |
+
else:
|
| 221 |
+
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
| 222 |
+
video_freq = video_freq.to(device)
|
| 223 |
+
vid_freqs.append(video_freq)
|
| 224 |
+
|
| 225 |
+
if self.scale_rope:
|
| 226 |
+
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
| 227 |
+
else:
|
| 228 |
+
max_vid_index = max(height, width, max_vid_index)
|
| 229 |
+
|
| 230 |
+
max_len = max(txt_seq_lens)
|
| 231 |
+
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
| 232 |
+
vid_freqs = torch.cat(vid_freqs, dim=0)
|
| 233 |
+
|
| 234 |
+
return vid_freqs, txt_freqs
|
| 235 |
+
|
| 236 |
+
@functools.lru_cache(maxsize=None)
|
| 237 |
+
def _compute_video_freqs(self, frame, height, width, idx=0):
|
| 238 |
+
seq_lens = frame * height * width
|
| 239 |
+
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
| 240 |
+
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
| 241 |
+
|
| 242 |
+
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
| 243 |
+
if self.scale_rope:
|
| 244 |
+
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
| 245 |
+
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
| 246 |
+
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
| 247 |
+
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
| 248 |
+
else:
|
| 249 |
+
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
| 250 |
+
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
| 251 |
+
|
| 252 |
+
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
| 253 |
+
return freqs.clone().contiguous()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class QwenDoubleStreamAttnProcessor2_0:
|
| 257 |
+
"""
|
| 258 |
+
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
| 259 |
+
implements joint attention computation where text and image streams are processed together.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
_attention_backend = None
|
| 263 |
+
|
| 264 |
+
def __init__(self):
|
| 265 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 266 |
+
raise ImportError(
|
| 267 |
+
"QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
def __call__(
|
| 271 |
+
self,
|
| 272 |
+
attn: Attention,
|
| 273 |
+
hidden_states: torch.FloatTensor, # Image stream
|
| 274 |
+
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
| 275 |
+
encoder_hidden_states_mask: torch.FloatTensor = None,
|
| 276 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 277 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 278 |
+
) -> torch.FloatTensor:
|
| 279 |
+
if encoder_hidden_states is None:
|
| 280 |
+
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
| 281 |
+
|
| 282 |
+
seq_txt = encoder_hidden_states.shape[1]
|
| 283 |
+
|
| 284 |
+
# Compute QKV for image stream (sample projections)
|
| 285 |
+
img_query = attn.to_q(hidden_states)
|
| 286 |
+
img_key = attn.to_k(hidden_states)
|
| 287 |
+
img_value = attn.to_v(hidden_states)
|
| 288 |
+
|
| 289 |
+
# Compute QKV for text stream (context projections)
|
| 290 |
+
txt_query = attn.add_q_proj(encoder_hidden_states)
|
| 291 |
+
txt_key = attn.add_k_proj(encoder_hidden_states)
|
| 292 |
+
txt_value = attn.add_v_proj(encoder_hidden_states)
|
| 293 |
+
|
| 294 |
+
# Reshape for multi-head attention
|
| 295 |
+
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
| 296 |
+
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
| 297 |
+
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
| 298 |
+
|
| 299 |
+
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
| 300 |
+
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
| 301 |
+
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
| 302 |
+
|
| 303 |
+
# Apply QK normalization
|
| 304 |
+
if attn.norm_q is not None:
|
| 305 |
+
img_query = attn.norm_q(img_query)
|
| 306 |
+
if attn.norm_k is not None:
|
| 307 |
+
img_key = attn.norm_k(img_key)
|
| 308 |
+
if attn.norm_added_q is not None:
|
| 309 |
+
txt_query = attn.norm_added_q(txt_query)
|
| 310 |
+
if attn.norm_added_k is not None:
|
| 311 |
+
txt_key = attn.norm_added_k(txt_key)
|
| 312 |
+
|
| 313 |
+
# Apply RoPE
|
| 314 |
+
if image_rotary_emb is not None:
|
| 315 |
+
img_freqs, txt_freqs = image_rotary_emb
|
| 316 |
+
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
| 317 |
+
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
| 318 |
+
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
| 319 |
+
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
| 320 |
+
|
| 321 |
+
# Concatenate for joint attention
|
| 322 |
+
# Order: [text, image]
|
| 323 |
+
joint_query = torch.cat([txt_query, img_query], dim=1)
|
| 324 |
+
joint_key = torch.cat([txt_key, img_key], dim=1)
|
| 325 |
+
joint_value = torch.cat([txt_value, img_value], dim=1)
|
| 326 |
+
|
| 327 |
+
# Compute joint attention
|
| 328 |
+
joint_hidden_states = dispatch_attention_fn(
|
| 329 |
+
joint_query,
|
| 330 |
+
joint_key,
|
| 331 |
+
joint_value,
|
| 332 |
+
attn_mask=attention_mask,
|
| 333 |
+
dropout_p=0.0,
|
| 334 |
+
is_causal=False,
|
| 335 |
+
backend=self._attention_backend,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Reshape back
|
| 339 |
+
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
| 340 |
+
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
|
| 341 |
+
|
| 342 |
+
# Split attention outputs back
|
| 343 |
+
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
| 344 |
+
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
| 345 |
+
|
| 346 |
+
# Apply output projections
|
| 347 |
+
img_attn_output = attn.to_out[0](img_attn_output)
|
| 348 |
+
if len(attn.to_out) > 1:
|
| 349 |
+
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
| 350 |
+
|
| 351 |
+
txt_attn_output = attn.to_add_out(txt_attn_output)
|
| 352 |
+
|
| 353 |
+
return img_attn_output, txt_attn_output
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@maybe_allow_in_graph
|
| 357 |
+
class QwenImageTransformerBlock(nn.Module):
|
| 358 |
+
def __init__(
|
| 359 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
| 360 |
+
):
|
| 361 |
+
super().__init__()
|
| 362 |
+
|
| 363 |
+
self.dim = dim
|
| 364 |
+
self.num_attention_heads = num_attention_heads
|
| 365 |
+
self.attention_head_dim = attention_head_dim
|
| 366 |
+
|
| 367 |
+
# Image processing modules
|
| 368 |
+
self.img_mod = nn.Sequential(
|
| 369 |
+
nn.SiLU(),
|
| 370 |
+
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
| 371 |
+
)
|
| 372 |
+
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 373 |
+
self.attn = Attention(
|
| 374 |
+
query_dim=dim,
|
| 375 |
+
cross_attention_dim=None, # Enable cross attention for joint computation
|
| 376 |
+
added_kv_proj_dim=dim, # Enable added KV projections for text stream
|
| 377 |
+
dim_head=attention_head_dim,
|
| 378 |
+
heads=num_attention_heads,
|
| 379 |
+
out_dim=dim,
|
| 380 |
+
context_pre_only=False,
|
| 381 |
+
bias=True,
|
| 382 |
+
processor=QwenDoubleStreamAttnProcessor2_0(),
|
| 383 |
+
qk_norm=qk_norm,
|
| 384 |
+
eps=eps,
|
| 385 |
+
)
|
| 386 |
+
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 387 |
+
self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 388 |
+
|
| 389 |
+
# Text processing modules
|
| 390 |
+
self.txt_mod = nn.Sequential(
|
| 391 |
+
nn.SiLU(),
|
| 392 |
+
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
| 393 |
+
)
|
| 394 |
+
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 395 |
+
# Text doesn't need separate attention - it's handled by img_attn joint computation
|
| 396 |
+
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 397 |
+
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 398 |
+
|
| 399 |
+
def _modulate(self, x, mod_params):
|
| 400 |
+
"""Apply modulation to input tensor"""
|
| 401 |
+
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
| 402 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
| 403 |
+
|
| 404 |
+
def forward(
|
| 405 |
+
self,
|
| 406 |
+
hidden_states: torch.Tensor,
|
| 407 |
+
encoder_hidden_states: torch.Tensor,
|
| 408 |
+
encoder_hidden_states_mask: torch.Tensor,
|
| 409 |
+
temb: torch.Tensor,
|
| 410 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 411 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 412 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 413 |
+
# Get modulation parameters for both streams
|
| 414 |
+
img_mod_params = self.img_mod(temb) # [B, 6*dim]
|
| 415 |
+
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
|
| 416 |
+
|
| 417 |
+
# Split modulation parameters for norm1 and norm2
|
| 418 |
+
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
| 419 |
+
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
| 420 |
+
|
| 421 |
+
# Process image stream - norm1 + modulation
|
| 422 |
+
img_normed = self.img_norm1(hidden_states)
|
| 423 |
+
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
| 424 |
+
|
| 425 |
+
# Process text stream - norm1 + modulation
|
| 426 |
+
txt_normed = self.txt_norm1(encoder_hidden_states)
|
| 427 |
+
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
| 428 |
+
|
| 429 |
+
# Use QwenAttnProcessor2_0 for joint attention computation
|
| 430 |
+
# This directly implements the DoubleStreamLayerMegatron logic:
|
| 431 |
+
# 1. Computes QKV for both streams
|
| 432 |
+
# 2. Applies QK normalization and RoPE
|
| 433 |
+
# 3. Concatenates and runs joint attention
|
| 434 |
+
# 4. Splits results back to separate streams
|
| 435 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 436 |
+
attn_output = self.attn(
|
| 437 |
+
hidden_states=img_modulated, # Image stream (will be processed as "sample")
|
| 438 |
+
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
|
| 439 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 440 |
+
image_rotary_emb=image_rotary_emb,
|
| 441 |
+
**joint_attention_kwargs,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
|
| 445 |
+
img_attn_output, txt_attn_output = attn_output
|
| 446 |
+
|
| 447 |
+
# Apply attention gates and add residual (like in Megatron)
|
| 448 |
+
hidden_states = hidden_states + img_gate1 * img_attn_output
|
| 449 |
+
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
| 450 |
+
|
| 451 |
+
# Process image stream - norm2 + MLP
|
| 452 |
+
img_normed2 = self.img_norm2(hidden_states)
|
| 453 |
+
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
| 454 |
+
img_mlp_output = self.img_mlp(img_modulated2)
|
| 455 |
+
hidden_states = hidden_states + img_gate2 * img_mlp_output
|
| 456 |
+
|
| 457 |
+
# Process text stream - norm2 + MLP
|
| 458 |
+
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
| 459 |
+
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
| 460 |
+
txt_mlp_output = self.txt_mlp(txt_modulated2)
|
| 461 |
+
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
| 462 |
+
|
| 463 |
+
# Clip to prevent overflow for fp16
|
| 464 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 465 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 466 |
+
if hidden_states.dtype == torch.float16:
|
| 467 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 468 |
+
|
| 469 |
+
return encoder_hidden_states, hidden_states
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
|
| 473 |
+
"""
|
| 474 |
+
The Transformer model introduced in Qwen.
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
patch_size (`int`, defaults to `2`):
|
| 478 |
+
Patch size to turn the input data into small patches.
|
| 479 |
+
in_channels (`int`, defaults to `64`):
|
| 480 |
+
The number of channels in the input.
|
| 481 |
+
out_channels (`int`, *optional*, defaults to `None`):
|
| 482 |
+
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
| 483 |
+
num_layers (`int`, defaults to `60`):
|
| 484 |
+
The number of layers of dual stream DiT blocks to use.
|
| 485 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 486 |
+
The number of dimensions to use for each attention head.
|
| 487 |
+
num_attention_heads (`int`, defaults to `24`):
|
| 488 |
+
The number of attention heads to use.
|
| 489 |
+
joint_attention_dim (`int`, defaults to `3584`):
|
| 490 |
+
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
| 491 |
+
`encoder_hidden_states`).
|
| 492 |
+
guidance_embeds (`bool`, defaults to `False`):
|
| 493 |
+
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
| 494 |
+
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 495 |
+
The dimensions to use for the rotary positional embeddings.
|
| 496 |
+
"""
|
| 497 |
+
|
| 498 |
+
_supports_gradient_checkpointing = True
|
| 499 |
+
_no_split_modules = ["QwenImageTransformerBlock"]
|
| 500 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 501 |
+
_repeated_blocks = ["QwenImageTransformerBlock"]
|
| 502 |
+
|
| 503 |
+
@register_to_config
|
| 504 |
+
def __init__(
|
| 505 |
+
self,
|
| 506 |
+
patch_size: int = 2,
|
| 507 |
+
in_channels: int = 64,
|
| 508 |
+
out_channels: Optional[int] = 16,
|
| 509 |
+
num_layers: int = 60,
|
| 510 |
+
attention_head_dim: int = 128,
|
| 511 |
+
num_attention_heads: int = 24,
|
| 512 |
+
joint_attention_dim: int = 3584,
|
| 513 |
+
guidance_embeds: bool = False, # TODO: this should probably be removed
|
| 514 |
+
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
| 515 |
+
):
|
| 516 |
+
super().__init__()
|
| 517 |
+
self.out_channels = out_channels or in_channels
|
| 518 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 519 |
+
|
| 520 |
+
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
| 521 |
+
|
| 522 |
+
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
| 523 |
+
|
| 524 |
+
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
| 525 |
+
|
| 526 |
+
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
| 527 |
+
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 528 |
+
|
| 529 |
+
self.transformer_blocks = nn.ModuleList(
|
| 530 |
+
[
|
| 531 |
+
QwenImageTransformerBlock(
|
| 532 |
+
dim=self.inner_dim,
|
| 533 |
+
num_attention_heads=num_attention_heads,
|
| 534 |
+
attention_head_dim=attention_head_dim,
|
| 535 |
+
)
|
| 536 |
+
for _ in range(num_layers)
|
| 537 |
+
]
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 541 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 542 |
+
|
| 543 |
+
self.gradient_checkpointing = False
|
| 544 |
+
|
| 545 |
+
def forward(
|
| 546 |
+
self,
|
| 547 |
+
hidden_states: torch.Tensor,
|
| 548 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 549 |
+
encoder_hidden_states_mask: torch.Tensor = None,
|
| 550 |
+
timestep: torch.LongTensor = None,
|
| 551 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 552 |
+
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
| 553 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 554 |
+
return_dict: bool = True,
|
| 555 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 556 |
+
"""
|
| 557 |
+
The [`QwenTransformer2DModel`] forward method.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 561 |
+
Input `hidden_states`.
|
| 562 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 563 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 564 |
+
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
| 565 |
+
Mask of the input conditions.
|
| 566 |
+
timestep ( `torch.LongTensor`):
|
| 567 |
+
Used to indicate denoising step.
|
| 568 |
+
attention_kwargs (`dict`, *optional*):
|
| 569 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 570 |
+
`self.processor` in
|
| 571 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 572 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 573 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 574 |
+
tuple.
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 578 |
+
`tuple` where the first element is the sample tensor.
|
| 579 |
+
"""
|
| 580 |
+
if attention_kwargs is not None:
|
| 581 |
+
attention_kwargs = attention_kwargs.copy()
|
| 582 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 583 |
+
else:
|
| 584 |
+
lora_scale = 1.0
|
| 585 |
+
|
| 586 |
+
if USE_PEFT_BACKEND:
|
| 587 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 588 |
+
scale_lora_layers(self, lora_scale)
|
| 589 |
+
else:
|
| 590 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 591 |
+
logger.warning(
|
| 592 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
hidden_states = self.img_in(hidden_states)
|
| 596 |
+
|
| 597 |
+
timestep = timestep.to(hidden_states.dtype)
|
| 598 |
+
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
| 599 |
+
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
| 600 |
+
|
| 601 |
+
if guidance is not None:
|
| 602 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 603 |
+
|
| 604 |
+
temb = (
|
| 605 |
+
self.time_text_embed(timestep, hidden_states)
|
| 606 |
+
if guidance is None
|
| 607 |
+
else self.time_text_embed(timestep, guidance, hidden_states)
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 611 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 612 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 613 |
+
block,
|
| 614 |
+
hidden_states,
|
| 615 |
+
encoder_hidden_states,
|
| 616 |
+
encoder_hidden_states_mask,
|
| 617 |
+
temb,
|
| 618 |
+
image_rotary_emb,
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
else:
|
| 622 |
+
encoder_hidden_states, hidden_states = block(
|
| 623 |
+
hidden_states=hidden_states,
|
| 624 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 625 |
+
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
| 626 |
+
temb=temb,
|
| 627 |
+
image_rotary_emb=image_rotary_emb,
|
| 628 |
+
joint_attention_kwargs=attention_kwargs,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Use only the image part (hidden_states) from the dual-stream blocks
|
| 632 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 633 |
+
output = self.proj_out(hidden_states)
|
| 634 |
+
|
| 635 |
+
if USE_PEFT_BACKEND:
|
| 636 |
+
# remove `lora_scale` from each PEFT layer
|
| 637 |
+
unscale_lora_layers(self, lora_scale)
|
| 638 |
+
|
| 639 |
+
if not return_dict:
|
| 640 |
+
return (output,)
|
| 641 |
+
|
| 642 |
+
return Transformer2DModelOutput(sample=output)
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/huggingface/transformers.git@v4.57.3
|
| 2 |
+
git+https://github.com/huggingface/accelerate.git
|
| 3 |
+
git+https://github.com/huggingface/diffusers.git
|
| 4 |
+
git+https://github.com/huggingface/peft.git
|
| 5 |
+
huggingface_hub
|
| 6 |
+
sentencepiece
|
| 7 |
+
torchvision
|
| 8 |
+
supervision
|
| 9 |
+
kernels
|
| 10 |
+
spaces
|
| 11 |
+
hf_xet
|
| 12 |
+
torch==2.9.1
|
| 13 |
+
numpy
|
| 14 |
+
av
|
setup_manager.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
# Configuration
|
| 6 |
+
WORKSPACE_DIR = "/workspace"
|
| 7 |
+
VENV_DIR = os.path.join(WORKSPACE_DIR, "venv")
|
| 8 |
+
APPS_DIR = os.path.join(WORKSPACE_DIR, "apps")
|
| 9 |
+
REPO_DIR = os.path.join(WORKSPACE_DIR, "Qwen-Image-Edit")
|
| 10 |
+
HF_TOKEN = "YOUR_HF_TOKEN_HERE"
|
| 11 |
+
|
| 12 |
+
# Cache and Temp Directories (Strictly on persistent drive)
|
| 13 |
+
CACHE_BASE = os.path.join(WORKSPACE_DIR, "cache")
|
| 14 |
+
TMP_DIR = os.path.join(WORKSPACE_DIR, "tmp")
|
| 15 |
+
PIP_CACHE = os.path.join(CACHE_BASE, "pip")
|
| 16 |
+
HF_HOME = os.path.join(CACHE_BASE, "huggingface")
|
| 17 |
+
|
| 18 |
+
def ensure_dirs():
|
| 19 |
+
"""Ensures all necessary persistent directories exist."""
|
| 20 |
+
dirs = [APPS_DIR, REPO_DIR, CACHE_BASE, TMP_DIR, PIP_CACHE, HF_HOME]
|
| 21 |
+
for d in dirs:
|
| 22 |
+
if not os.path.exists(d):
|
| 23 |
+
os.makedirs(d)
|
| 24 |
+
print(f"Created directory: {d}")
|
| 25 |
+
|
| 26 |
+
def run_command(command, cwd=None, env=None):
|
| 27 |
+
"""Runs a shell command and prints output."""
|
| 28 |
+
print(f"Running: {command}")
|
| 29 |
+
current_env = os.environ.copy()
|
| 30 |
+
|
| 31 |
+
# Force use of persistent directories
|
| 32 |
+
current_env["TMPDIR"] = TMP_DIR
|
| 33 |
+
current_env["PIP_CACHE_DIR"] = PIP_CACHE
|
| 34 |
+
current_env["HF_HOME"] = HF_HOME
|
| 35 |
+
|
| 36 |
+
if env:
|
| 37 |
+
current_env.update(env)
|
| 38 |
+
|
| 39 |
+
process = subprocess.Popen(
|
| 40 |
+
command,
|
| 41 |
+
shell=True,
|
| 42 |
+
stdout=subprocess.PIPE,
|
| 43 |
+
stderr=subprocess.STDOUT,
|
| 44 |
+
text=True,
|
| 45 |
+
cwd=cwd,
|
| 46 |
+
env=current_env
|
| 47 |
+
)
|
| 48 |
+
for line in process.stdout:
|
| 49 |
+
print(line, end="")
|
| 50 |
+
process.wait()
|
| 51 |
+
if process.returncode != 0:
|
| 52 |
+
print(f"Command failed with return code {process.returncode}")
|
| 53 |
+
return process.returncode
|
| 54 |
+
|
| 55 |
+
def setup_venv():
|
| 56 |
+
"""Sets up a persistent virtual environment in /workspace."""
|
| 57 |
+
if not os.path.exists(VENV_DIR):
|
| 58 |
+
print(f"Creating virtual environment in {VENV_DIR}...")
|
| 59 |
+
run_command(f"python3 -m venv {VENV_DIR}")
|
| 60 |
+
else:
|
| 61 |
+
print("Virtual environment already exists.")
|
| 62 |
+
|
| 63 |
+
def install_package(package_name):
|
| 64 |
+
"""Installs a pip package into the persistent venv."""
|
| 65 |
+
pip_path = os.path.join(VENV_DIR, "bin", "pip")
|
| 66 |
+
run_command(f"{pip_path} install {package_name}")
|
| 67 |
+
|
| 68 |
+
def install_git_xet():
|
| 69 |
+
"""Installs git-xet using the huggingface script."""
|
| 70 |
+
print("Installing git-xet...")
|
| 71 |
+
run_command("curl -LsSf https://huggingface.co/install-git-xet.sh | bash")
|
| 72 |
+
run_command("git xet install")
|
| 73 |
+
|
| 74 |
+
def install_hf_cli():
|
| 75 |
+
"""Installs Hugging Face CLI."""
|
| 76 |
+
print("Installing Hugging Face CLI...")
|
| 77 |
+
run_command("curl -LsSf https://hf.co/cli/install.sh | bash")
|
| 78 |
+
|
| 79 |
+
def download_space():
|
| 80 |
+
"""Downloads the Qwen Space using hf cli."""
|
| 81 |
+
if not os.path.exists(REPO_DIR):
|
| 82 |
+
os.makedirs(REPO_DIR)
|
| 83 |
+
|
| 84 |
+
print(f"Downloading Space to {REPO_DIR}...")
|
| 85 |
+
# Using full path to hf if it's in ~/.local/bin
|
| 86 |
+
hf_path = os.path.expanduser("~/.local/bin/hf")
|
| 87 |
+
if not os.path.exists(hf_path):
|
| 88 |
+
hf_path = "hf" # fallback to PATH
|
| 89 |
+
|
| 90 |
+
env = {"HF_TOKEN": HF_TOKEN}
|
| 91 |
+
run_command(f"{hf_path} download Pr0f3ssi0n4ln00b/Qwen-Image-Edit-Rapid-AIO-Loras-Experimental --repo-type=space --local-dir {REPO_DIR}", env=env)
|
| 92 |
+
|
| 93 |
+
def create_app_file(filename, content):
|
| 94 |
+
"""Creates/Updates a file in the apps directory."""
|
| 95 |
+
if not os.path.exists(APPS_DIR):
|
| 96 |
+
os.makedirs(APPS_DIR)
|
| 97 |
+
|
| 98 |
+
filepath = os.path.join(APPS_DIR, filename)
|
| 99 |
+
with open(filepath, "w") as f:
|
| 100 |
+
f.write(content)
|
| 101 |
+
print(f"Created/Updated: {filepath}")
|
| 102 |
+
|
| 103 |
+
def patch_app():
|
| 104 |
+
"""Patches app.py to optimize for VRAM and fix OOM issues."""
|
| 105 |
+
app_path = os.path.join(REPO_DIR, "app.py")
|
| 106 |
+
if not os.path.exists(app_path):
|
| 107 |
+
print(f"Warning: {app_path} not found, cannot patch.")
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
print("Patching app.py for memory optimization...")
|
| 111 |
+
with open(app_path, "r") as f:
|
| 112 |
+
content = f.read()
|
| 113 |
+
|
| 114 |
+
# 1. Update transformer loading to use device_map="auto" and low_cpu_mem_usage
|
| 115 |
+
content = content.replace(
|
| 116 |
+
'device_map="cuda",',
|
| 117 |
+
'device_map="auto",\n low_cpu_mem_usage=True,'
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# 2. Remove redundant .to(device) which causes OOM
|
| 121 |
+
content = content.replace(').to(device)', ')')
|
| 122 |
+
|
| 123 |
+
# 3. Enable model CPU offload to save VRAM
|
| 124 |
+
if "p.enable_model_cpu_offload()" not in content:
|
| 125 |
+
content = content.replace(
|
| 126 |
+
'return p',
|
| 127 |
+
'p.enable_model_cpu_offload()\n return p'
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# 4. Disable FA3 Processor (to avoid hangs/compilation issues)
|
| 131 |
+
content = content.replace(
|
| 132 |
+
'pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())',
|
| 133 |
+
'print("Skipping FA3 optimization for stability.")'
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# 5. Fix launch parameters for visibility and accessibility
|
| 137 |
+
content = content.replace(
|
| 138 |
+
'demo.queue(max_size=30).launch(',
|
| 139 |
+
'demo.queue(max_size=30).launch(server_name="0.0.0.0", share=True, '
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# 6. Ensure spaces.GPU is handled (if it blocks)
|
| 143 |
+
# Usually it's fine, but let's be safe and mock it if env isn't right
|
| 144 |
+
if 'import spaces' in content and 'class spaces:' not in content:
|
| 145 |
+
content = 'import sys\ntry:\n import spaces\nexcept ImportError:\n class spaces:\n @staticmethod\n def GPU(f): return f\nsys.modules["spaces"] = sys.modules.get("spaces", spaces)\n' + content
|
| 146 |
+
|
| 147 |
+
# 7. Add missing LORA_PRESET_PROMPTS (Robust append)
|
| 148 |
+
additional_prompts_map = {
|
| 149 |
+
"Consistance": "improve consistency and quality of the generated image",
|
| 150 |
+
"F2P": "transform the image into a high-quality photo with realistic details",
|
| 151 |
+
"Multiple-Angles": "change the camera angle of the image",
|
| 152 |
+
"Light-Restoration": "Remove shadows and relight the image using soft lighting",
|
| 153 |
+
"Relight": "Relight the image with cinematic lighting",
|
| 154 |
+
"Multi-Angle-Lighting": "Change the lighting direction and intensity",
|
| 155 |
+
"Edit-Skin": "Enhance skin textures and natural details",
|
| 156 |
+
"Next-Scene": "Generate the next scene based on the current image",
|
| 157 |
+
"Flat-Log": "Desaturate and lower contrast for a flat log look",
|
| 158 |
+
"Upscale-Image": "Enhance and sharpen the image details",
|
| 159 |
+
"BFS-Best-FaceSwap": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 160 |
+
"BFS-Best-FaceSwap-merge": "head_swap : start with Picture 1 as the base image, keeping its lighting, environment, and background. remove the head from Picture 1 completely and replace it with the head from Picture 2, strictly preserving the hair, eye color, and nose structure, mouth, lips and front head of Picture 2. copy the eye direction, head rotation, and micro-expressions from Picture 1. high quality, sharp details, 4k",
|
| 161 |
+
"Qwen-lora-nsfw": "Convert this picture to artistic style.", # Default prompt
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# 9. Add new LoRA to ADAPTER_SPECS
|
| 165 |
+
new_lora_config = """
|
| 166 |
+
"Qwen-lora-nsfw": {
|
| 167 |
+
"type": "single",
|
| 168 |
+
"repo": "wiikoo/Qwen-lora-nsfw",
|
| 169 |
+
"weights": "loras/qwen_image_edit_remove-clothing_v1.0.safetensors",
|
| 170 |
+
"adapter_name": "qwen-lora-nsfw",
|
| 171 |
+
"strength": 1.0,
|
| 172 |
+
},
|
| 173 |
+
"""
|
| 174 |
+
if '"Qwen-lora-nsfw":' not in content:
|
| 175 |
+
content = content.replace(
|
| 176 |
+
'ADAPTER_SPECS = {',
|
| 177 |
+
'ADAPTER_SPECS = {' + new_lora_config
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if "Manual Patch for missing prompts" not in content:
|
| 181 |
+
content += "\n\n# Manual Patch for missing prompts\ntry:\n LORA_PRESET_PROMPTS.update({\n"
|
| 182 |
+
for key, val in additional_prompts_map.items():
|
| 183 |
+
content += f' "{key}": "{val}",\n'
|
| 184 |
+
content += " })\nexcept NameError:\n pass\n"
|
| 185 |
+
|
| 186 |
+
# 8. Modify on_lora_change_ui to ALWAYS update the prompt if a style is picked
|
| 187 |
+
# (or at least be more aggressive)
|
| 188 |
+
new_ui_logic = """
|
| 189 |
+
def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
|
| 190 |
+
# Always provide the preset if selected
|
| 191 |
+
prompt_val = current_prompt
|
| 192 |
+
if selected_lora != NONE_LORA:
|
| 193 |
+
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 194 |
+
if preset:
|
| 195 |
+
prompt_val = preset
|
| 196 |
+
|
| 197 |
+
prompt_update = gr.update(value=prompt_val)
|
| 198 |
+
"""
|
| 199 |
+
# Find the old function and replace it
|
| 200 |
+
start_marker = "def on_lora_change_ui"
|
| 201 |
+
end_marker = "return prompt_update, img2_update, extras_update"
|
| 202 |
+
|
| 203 |
+
if start_marker in content and end_marker in content:
|
| 204 |
+
import re
|
| 205 |
+
content = re.sub(
|
| 206 |
+
r"def on_lora_change_ui\(.*?\):.*?return prompt_update, img2_update, extras_update",
|
| 207 |
+
new_ui_logic + "\n # Image2 visibility/label\n if lora_requires_two_images(selected_lora):\n img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))\n else:\n img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')\n\n # Extra references routing default\n if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):\n extras_update = gr.update(value=True)\n else:\n extras_update = gr.update(value=extras_condition_only)\n\n return prompt_update, img2_update, extras_update",
|
| 208 |
+
content,
|
| 209 |
+
flags=re.DOTALL
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
with open(app_path, "w") as f:
|
| 213 |
+
f.write(content)
|
| 214 |
+
|
| 215 |
+
# --- NEW UI PATCHES ---
|
| 216 |
+
with open(app_path, "r") as f:
|
| 217 |
+
content = f.read()
|
| 218 |
+
|
| 219 |
+
# 10. Implement missing _append_to_gallery function
|
| 220 |
+
append_fn = """
|
| 221 |
+
def _append_to_gallery(existing_gallery, new_image):
|
| 222 |
+
if existing_gallery is None:
|
| 223 |
+
return [new_image]
|
| 224 |
+
if not isinstance(existing_gallery, list):
|
| 225 |
+
existing_gallery = [existing_gallery]
|
| 226 |
+
existing_gallery.append(new_image)
|
| 227 |
+
return existing_gallery
|
| 228 |
+
"""
|
| 229 |
+
if "def _append_to_gallery" not in content:
|
| 230 |
+
content = content.replace(
|
| 231 |
+
'# UI helpers: output routing + derived conditioning',
|
| 232 |
+
'# UI helpers: output routing + derived conditioning\n' + append_fn
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# 11. Remove height constraints from main image components
|
| 236 |
+
content = content.replace('height=290)', ')')
|
| 237 |
+
content = content.replace('height=350)', ')')
|
| 238 |
+
|
| 239 |
+
# 12. Strip out gr.Examples block to declutter UI
|
| 240 |
+
# We find the start of gr.Examples and the end of its call
|
| 241 |
+
if "gr.Examples(" in content:
|
| 242 |
+
import re
|
| 243 |
+
content = re.sub(
|
| 244 |
+
r"gr\.Examples\([\s\S]*?label=\"Examples\"[\s\S]*?\)",
|
| 245 |
+
"# Examples removed automatically by setup_manager",
|
| 246 |
+
content
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
with open(app_path, "w") as f:
|
| 250 |
+
f.write(content)
|
| 251 |
+
# --- END NEW UI PATCHES ---
|
| 252 |
+
|
| 253 |
+
# --- 3D CAMERA AND PROMPT CLEARING PATCHES ---
|
| 254 |
+
with open(app_path, "r") as f:
|
| 255 |
+
content = f.read()
|
| 256 |
+
|
| 257 |
+
# Import the custom 3D Camera control safely at the top
|
| 258 |
+
if "update_prompt_with_camera" not in content:
|
| 259 |
+
content = content.replace("import os", "import os\nfrom camera_control_ui import CameraControl3D, build_camera_prompt, update_prompt_with_camera")
|
| 260 |
+
|
| 261 |
+
# Add the 3D Camera LoRA to ADAPTER_SPECS
|
| 262 |
+
camera_lora_config = """
|
| 263 |
+
"3D-Camera": {
|
| 264 |
+
"type": "single",
|
| 265 |
+
"repo": "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
|
| 266 |
+
"weights": "qwen-image-edit-2511-multiple-angles-lora.safetensors",
|
| 267 |
+
"adapter_name": "angles",
|
| 268 |
+
"strength": 1.0,
|
| 269 |
+
},
|
| 270 |
+
"""
|
| 271 |
+
if '"3D-Camera":' not in content:
|
| 272 |
+
content = content.replace(
|
| 273 |
+
'ADAPTER_SPECS = {',
|
| 274 |
+
'ADAPTER_SPECS = {' + camera_lora_config
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Patch on_lora_change_ui to clear prompt if no preset exists and toggle 3D camera visibility
|
| 278 |
+
prompt_clear_logic = """
|
| 279 |
+
def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
|
| 280 |
+
prompt_val = current_prompt
|
| 281 |
+
if selected_lora != NONE_LORA:
|
| 282 |
+
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 283 |
+
if preset:
|
| 284 |
+
prompt_val = preset
|
| 285 |
+
else:
|
| 286 |
+
prompt_val = "" # CLEAR THE PROMPT IF ACTIVE BUT NO PRESET
|
| 287 |
+
|
| 288 |
+
prompt_update = gr.update(value=prompt_val)
|
| 289 |
+
camera_update = gr.update(visible=(selected_lora == "3D-Camera"))
|
| 290 |
+
|
| 291 |
+
# Image2 visibility/label
|
| 292 |
+
if lora_requires_two_images(selected_lora):
|
| 293 |
+
img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
|
| 294 |
+
else:
|
| 295 |
+
img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')
|
| 296 |
+
|
| 297 |
+
# Extra references routing default
|
| 298 |
+
if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
|
| 299 |
+
extras_update = gr.update(value=True)
|
| 300 |
+
else:
|
| 301 |
+
extras_update = gr.update(value=extras_condition_only)
|
| 302 |
+
|
| 303 |
+
return prompt_update, img2_update, extras_update, camera_update
|
| 304 |
+
"""
|
| 305 |
+
old_on_lora = """
|
| 306 |
+
def on_lora_change_ui(selected_lora, current_prompt, extras_condition_only):
|
| 307 |
+
# Always provide the preset if selected
|
| 308 |
+
prompt_val = current_prompt
|
| 309 |
+
if selected_lora != NONE_LORA:
|
| 310 |
+
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 311 |
+
if preset:
|
| 312 |
+
prompt_val = preset
|
| 313 |
+
|
| 314 |
+
prompt_update = gr.update(value=prompt_val)
|
| 315 |
+
|
| 316 |
+
# Image2 visibility/label
|
| 317 |
+
if lora_requires_two_images(selected_lora):
|
| 318 |
+
img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
|
| 319 |
+
else:
|
| 320 |
+
img2_update = gr.update(visible=False, value=None, label='Upload Reference (Image 2)')
|
| 321 |
+
|
| 322 |
+
# Extra references routing default
|
| 323 |
+
if selected_lora in ('BFS-Best-FaceSwap', 'BFS-Best-FaceSwap-merge', 'AnyPose'):
|
| 324 |
+
extras_update = gr.update(value=True)
|
| 325 |
+
else:
|
| 326 |
+
extras_update = gr.update(value=extras_condition_only)
|
| 327 |
+
|
| 328 |
+
return prompt_update, img2_update, extras_update
|
| 329 |
+
"""
|
| 330 |
+
if "camera_update = gr.update(visible" not in content:
|
| 331 |
+
content = content.replace(old_on_lora.strip(), prompt_clear_logic.strip())
|
| 332 |
+
|
| 333 |
+
# We also need to update the caller
|
| 334 |
+
content = content.replace(
|
| 335 |
+
"outputs=[prompt, input_image_2, extras_condition_only],",
|
| 336 |
+
"outputs=[prompt, input_image_2, extras_condition_only, camera_container],"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# Inject the 3D Camera UI Block right below input_image_2 definition
|
| 340 |
+
camera_ui_block = """
|
| 341 |
+
input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
|
| 342 |
+
|
| 343 |
+
with gr.Column(visible=False) as camera_container:
|
| 344 |
+
gr.Markdown("### 🎮 3D Camera Control\\n*Drag handles: 🟢 Azimuth, 🩷 Elevation, 🟠 Distance*")
|
| 345 |
+
camera_3d = CameraControl3D(value={"azimuth": 0, "elevation": 0, "distance": 1.0}, elem_id="camera-3d-control")
|
| 346 |
+
gr.Markdown("### 🎚️ Slider Controls")
|
| 347 |
+
azimuth_slider = gr.Slider(label="Azimuth", minimum=0, maximum=315, step=45, value=0, info="0°=front, 90°=right, 180°=back, 270°=left")
|
| 348 |
+
elevation_slider = gr.Slider(label="Elevation", minimum=-30, maximum=60, step=30, value=0, info="-30°=low angle, 0°=eye, 60°=high angle")
|
| 349 |
+
distance_slider = gr.Slider(label="Distance", minimum=0.6, maximum=1.4, step=0.4, value=1.0, info="0.6=close, 1.0=medium, 1.4=wide")
|
| 350 |
+
"""
|
| 351 |
+
if "camera_container:" not in content:
|
| 352 |
+
content = content.replace(
|
| 353 |
+
' input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)',
|
| 354 |
+
camera_ui_block.strip("\\n")
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# Inject the Events. We place them right before "run_button.click("
|
| 358 |
+
camera_events = """
|
| 359 |
+
# --- 3D Camera Events ---
|
| 360 |
+
def update_prompt_from_sliders(az, el, dist, curr_prompt):
|
| 361 |
+
return update_prompt_with_camera(az, el, dist, curr_prompt)
|
| 362 |
+
|
| 363 |
+
def sync_3d_to_sliders(cv, curr_prompt):
|
| 364 |
+
if cv and isinstance(cv, dict):
|
| 365 |
+
az = cv.get('azimuth', 0)
|
| 366 |
+
el = cv.get('elevation', 0)
|
| 367 |
+
dist = cv.get('distance', 1.0)
|
| 368 |
+
return az, el, dist, update_prompt_with_camera(az, el, dist, curr_prompt)
|
| 369 |
+
return gr.update(), gr.update(), gr.update(), gr.update()
|
| 370 |
+
|
| 371 |
+
def sync_sliders_to_3d(az, el, dist):
|
| 372 |
+
return {"azimuth": az, "elevation": el, "distance": dist}
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def update_3d_image(img):
|
| 376 |
+
if img is None: return gr.update(imageUrl=None)
|
| 377 |
+
import base64
|
| 378 |
+
from io import BytesIO
|
| 379 |
+
buf = BytesIO()
|
| 380 |
+
img.save(buf, format="PNG")
|
| 381 |
+
durl = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
|
| 382 |
+
return gr.update(imageUrl=durl)
|
| 383 |
+
|
| 384 |
+
for slider in [azimuth_slider, elevation_slider, distance_slider]:
|
| 385 |
+
slider.change(fn=update_prompt_from_sliders, inputs=[azimuth_slider, elevation_slider, distance_slider, prompt], outputs=[prompt])
|
| 386 |
+
slider.release(fn=sync_sliders_to_3d, inputs=[azimuth_slider, elevation_slider, distance_slider], outputs=[camera_3d])
|
| 387 |
+
|
| 388 |
+
camera_3d.change(fn=sync_3d_to_sliders, inputs=[camera_3d, prompt], outputs=[azimuth_slider, elevation_slider, distance_slider, prompt])
|
| 389 |
+
|
| 390 |
+
input_image_1.upload(fn=update_3d_image, inputs=[input_image_1], outputs=[camera_3d])
|
| 391 |
+
input_image_1.clear(fn=lambda: gr.update(imageUrl=None), outputs=[camera_3d])
|
| 392 |
+
|
| 393 |
+
run_button.click(
|
| 394 |
+
"""
|
| 395 |
+
if "def sync_3d_to_sliders" not in content:
|
| 396 |
+
content = content.replace(" run_button.click(\n", camera_events)
|
| 397 |
+
|
| 398 |
+
# Clear any bad \\n literals if they exist
|
| 399 |
+
content = content.replace("\\n demo.queue", "\n demo.queue")
|
| 400 |
+
|
| 401 |
+
if "head=" not in content:
|
| 402 |
+
content = content.replace(
|
| 403 |
+
"demo.queue(max_size=30).launch(",
|
| 404 |
+
"""head = '<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>'
|
| 405 |
+
demo.queue(max_size=30).launch(head=head, """
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
with open(app_path, "w") as f:
|
| 409 |
+
f.write(content)
|
| 410 |
+
# --- END 3D CAMERA PATCHES ---
|
| 411 |
+
|
| 412 |
+
print("Successfully patched app.py.")
|
| 413 |
+
|
| 414 |
+
def install_dependencies():
|
| 415 |
+
"""Installs dependencies from requirements.txt into the persistent venv."""
|
| 416 |
+
pip_path = os.path.join(VENV_DIR, "bin", "pip")
|
| 417 |
+
requirements_path = os.path.join(REPO_DIR, "requirements.txt")
|
| 418 |
+
|
| 419 |
+
if os.path.exists(requirements_path):
|
| 420 |
+
print("Installing dependencies from requirements.txt...")
|
| 421 |
+
# Note: torch 2.9.1 might not exist on PyPI, checking if it needs --extra-index-url
|
| 422 |
+
# For L40S, we typically want the latest stable torch with CUDA 12.x
|
| 423 |
+
run_command(f"{pip_path} install -r {requirements_path}")
|
| 424 |
+
else:
|
| 425 |
+
print(f"No requirements.txt found in {REPO_DIR}")
|
| 426 |
+
|
| 427 |
+
def run_app():
|
| 428 |
+
"""Starts the Gradio app."""
|
| 429 |
+
python_path = os.path.join(VENV_DIR, "bin", "python")
|
| 430 |
+
app_path = os.path.join(REPO_DIR, "app.py")
|
| 431 |
+
|
| 432 |
+
if os.path.exists(app_path):
|
| 433 |
+
print(f"Starting app: {app_path}")
|
| 434 |
+
# Gradio apps often need to be bound to 0.0.0.0 for external access
|
| 435 |
+
# We'll run it and see if it requires specific environment variables
|
| 436 |
+
env = {"PYTHONPATH": REPO_DIR}
|
| 437 |
+
run_command(f"{python_path} {app_path}", cwd=REPO_DIR, env=env)
|
| 438 |
+
else:
|
| 439 |
+
print(f"App file not found: {app_path}")
|
| 440 |
+
|
| 441 |
+
def main():
|
| 442 |
+
# Ensure workspace exists
|
| 443 |
+
if not os.path.exists(WORKSPACE_DIR):
|
| 444 |
+
print(f"Error: {WORKSPACE_DIR} not found. Ensure this is a RunPod with persistent storage.")
|
| 445 |
+
return
|
| 446 |
+
|
| 447 |
+
ensure_dirs()
|
| 448 |
+
setup_venv()
|
| 449 |
+
install_git_xet()
|
| 450 |
+
install_hf_cli()
|
| 451 |
+
download_space()
|
| 452 |
+
patch_app()
|
| 453 |
+
install_dependencies()
|
| 454 |
+
|
| 455 |
+
# We don't call run_app here by default to allow script updates
|
| 456 |
+
print("Setup tasks completed. Run with 'run' argument to start the app.")
|
| 457 |
+
|
| 458 |
+
if __name__ == "__main__":
|
| 459 |
+
if len(sys.argv) > 1 and sys.argv[1] == "run":
|
| 460 |
+
run_app()
|
| 461 |
+
else:
|
| 462 |
+
main()
|
start_app.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
export PYTHONPATH=/workspace/Qwen-Image-Edit
|
| 3 |
+
export TMPDIR=/workspace/tmp
|
| 4 |
+
export HF_HOME=/workspace/cache/huggingface
|
| 5 |
+
export PYTHONUNBUFFERED=1
|
| 6 |
+
cd /workspace/Qwen-Image-Edit
|
| 7 |
+
exec /workspace/venv/bin/python -u /workspace/Qwen-Image-Edit/app.py
|