diff --git a/.gitattributes b/.gitattributes
index c7d9f3332a950355d5a77d85000f05e6f45435ea..d72a6d7a69050db0cfb95dead9f4a58348e99c1f 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+results/00003_out.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..fb5be37ca33c1dd7dc4200105d705164ff9b5196
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright (c) 2023, Mension
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 86894d3ff4bf6116ff7fb87e05b276d576bd021e..741493ee91125bd57cd76624aacfd5f3c6b01944 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,2 @@
----
-title: Real ESRGAN Enhanced Anime Diffusion
-emoji: 🚀
-colorFrom: gray
-colorTo: indigo
-sdk: gradio
-sdk_version: 3.16.1
-app_file: app.py
-pinned: false
-license: bsd-3-clause-clear
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Real-ESRGAN-Enhanced-Anime-Diffusion
+Generate high resolution and quality anime pictures from texts or existed images.
diff --git a/README_HG.md b/README_HG.md
new file mode 100644
index 0000000000000000000000000000000000000000..99a0776d1a4669fa8387cc77e162c60084100a92
--- /dev/null
+++ b/README_HG.md
@@ -0,0 +1,12 @@
+---
+title: Anything V3.0
+emoji: 🏃
+colorFrom: gray
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.10.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/RealESRGANv030/.github/workflows/no-response.yml b/RealESRGANv030/.github/workflows/no-response.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fa702eeacff13fe8475b0e102a8b8c37602f3963
--- /dev/null
+++ b/RealESRGANv030/.github/workflows/no-response.yml
@@ -0,0 +1,33 @@
+name: No Response
+
+# TODO: it seems not to work
+# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
+
+# **What it does**: Closes issues that don't have enough information to be actionable.
+# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
+# to see if contributors have responded.
+# **Who does it impact**: Everyone that works on docs or docs-internal.
+
+on:
+ issue_comment:
+ types: [created]
+
+ schedule:
+ # Schedule for five minutes after the hour every hour
+ - cron: '5 * * * *'
+
+jobs:
+ noResponse:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: lee-dohm/no-response@v0.5.0
+ with:
+ token: ${{ github.token }}
+ closeComment: >
+ This issue has been automatically closed because there has been no response
+ to our request for more information from the original author. With only the
+ information that is currently in the issue, we don't have enough information
+ to take action. Please reach out if you have or find the answers we need so
+ that we can investigate further.
+ If you still have questions, please improve your description and re-open it.
+ Thanks :-)
diff --git a/RealESRGANv030/.github/workflows/publish-pip.yml b/RealESRGANv030/.github/workflows/publish-pip.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f3c8e574fd59fa9a4f3925eee9ee590dbdca965a
--- /dev/null
+++ b/RealESRGANv030/.github/workflows/publish-pip.yml
@@ -0,0 +1,33 @@
+name: PyPI Publish
+
+on: push
+
+jobs:
+ build-n-publish:
+ runs-on: ubuntu-latest
+ if: startsWith(github.event.ref, 'refs/tags')
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.8
+ - name: Upgrade pip
+ run: pip install pip --upgrade
+ - name: Install PyTorch (cpu)
+ run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Install dependencies
+ run: |
+ pip install basicsr
+ pip install facexlib
+ pip install gfpgan
+ pip install -r requirements.txt
+ - name: Build and install
+ run: rm -rf .eggs && pip install -e .
+ - name: Build for distribution
+ run: python setup.py sdist bdist_wheel
+ - name: Publish distribution to PyPI
+ uses: pypa/gh-action-pypi-publish@master
+ with:
+ password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/RealESRGANv030/.github/workflows/pylint.yml b/RealESRGANv030/.github/workflows/pylint.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2084d1aa236b948d8734b6762d3e01054580001a
--- /dev/null
+++ b/RealESRGANv030/.github/workflows/pylint.yml
@@ -0,0 +1,31 @@
+name: PyLint
+
+on: [push, pull_request]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.8]
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install codespell flake8 isort yapf
+
+ # modify the folders accordingly
+ - name: Lint
+ run: |
+ codespell
+ flake8 .
+ isort --check-only --diff realesrgan/ scripts/ inference_realesrgan.py setup.py
+ yapf -r -d realesrgan/ scripts/ inference_realesrgan.py setup.py
diff --git a/RealESRGANv030/.github/workflows/release.yml b/RealESRGANv030/.github/workflows/release.yml
new file mode 100644
index 0000000000000000000000000000000000000000..18be9e5c31768ab3be3e1075500a35bcb5783434
--- /dev/null
+++ b/RealESRGANv030/.github/workflows/release.yml
@@ -0,0 +1,41 @@
+name: release
+on:
+ push:
+ tags:
+ - '*'
+
+jobs:
+ build:
+ permissions: write-all
+ name: Create Release
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+ - name: Create Release
+ id: create_release
+ uses: actions/create-release@v1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ tag_name: ${{ github.ref }}
+ release_name: Real-ESRGAN ${{ github.ref }} Release Note
+ body: |
+ 🚀 See you again 😸
+ 🚀Have a nice day 😸 and happy everyday 😃
+ 🚀 Long time no see ☄️
+
+ ✨ **Highlights**
+ ✅ [Features] Support ...
+
+ 🐛 **Bug Fixes**
+
+ 🌴 **Improvements**
+
+ 📢📢📢
+
+
+
+
+ draft: true
+ prerelease: false
diff --git a/RealESRGANv030/.gitignore b/RealESRGANv030/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..bb86ed0fd8a71305c7d8cc794bfa4591a5ccbc99
--- /dev/null
+++ b/RealESRGANv030/.gitignore
@@ -0,0 +1,140 @@
+# ignored folders
+datasets/*
+experiments/*
+results/*
+tb_logger/*
+wandb/*
+tmp/*
+weights/*
+
+version.py
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/RealESRGANv030/.pre-commit-config.yaml b/RealESRGANv030/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d221d29fbaac74bef1c0cd910ce8d8b6526181b8
--- /dev/null
+++ b/RealESRGANv030/.pre-commit-config.yaml
@@ -0,0 +1,46 @@
+repos:
+ # flake8
+ - repo: https://github.com/PyCQA/flake8
+ rev: 3.8.3
+ hooks:
+ - id: flake8
+ args: ["--config=setup.cfg", "--ignore=W504, W503"]
+
+ # modify known_third_party
+ - repo: https://github.com/asottile/seed-isort-config
+ rev: v2.2.0
+ hooks:
+ - id: seed-isort-config
+
+ # isort
+ - repo: https://github.com/timothycrosley/isort
+ rev: 5.2.2
+ hooks:
+ - id: isort
+
+ # yapf
+ - repo: https://github.com/pre-commit/mirrors-yapf
+ rev: v0.30.0
+ hooks:
+ - id: yapf
+
+ # codespell
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.1.0
+ hooks:
+ - id: codespell
+
+ # pre-commit-hooks
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v3.2.0
+ hooks:
+ - id: trailing-whitespace # Trim trailing whitespace
+ - id: check-yaml # Attempt to load all yaml files to verify syntax
+ - id: check-merge-conflict # Check for files that contain merge conflict strings
+ - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
+ - id: end-of-file-fixer # Make sure files end in a newline and only a newline
+ - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
+ - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
+ args: ["--remove"]
+ - id: mixed-line-ending # Replace or check mixed line ending
+ args: ["--fix=lf"]
diff --git a/RealESRGANv030/.vscode/settings.json b/RealESRGANv030/.vscode/settings.json
new file mode 100644
index 0000000000000000000000000000000000000000..b12635534688a8a8c69033d81fad96ef734ea6bb
--- /dev/null
+++ b/RealESRGANv030/.vscode/settings.json
@@ -0,0 +1,19 @@
+{
+ "files.trimTrailingWhitespace": true,
+ "editor.wordWrap": "on",
+ "editor.rulers": [
+ 80,
+ 120
+ ],
+ "editor.renderWhitespace": "all",
+ "editor.renderControlCharacters": true,
+ "python.formatting.provider": "yapf",
+ "python.formatting.yapfArgs": [
+ "--style",
+ "{BASED_ON_STYLE = pep8, BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true, SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true, COLUMN_LIMIT = 120}"
+ ],
+ "python.linting.flake8Enabled": true,
+ "python.linting.flake8Args": [
+ "max-line-length=120"
+ ],
+}
diff --git a/RealESRGANv030/CODE_OF_CONDUCT.md b/RealESRGANv030/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..e8cc4daa4345590464314889b187d6a2d7a8e20f
--- /dev/null
+++ b/RealESRGANv030/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+xintao.wang@outlook.com or xintaowang@tencent.com.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/RealESRGANv030/LICENSE b/RealESRGANv030/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..552a1eeaf01f4e7077013ed3496600c608f35202
--- /dev/null
+++ b/RealESRGANv030/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/RealESRGANv030/MANIFEST.in b/RealESRGANv030/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..b87c827c894c82b5530c1267ea1d57e86c5f515b
--- /dev/null
+++ b/RealESRGANv030/MANIFEST.in
@@ -0,0 +1,8 @@
+include assets/*
+include inputs/*
+include scripts/*.py
+include inference_realesrgan.py
+include VERSION
+include LICENSE
+include requirements.txt
+include weights/README.md
diff --git a/RealESRGANv030/README.md b/RealESRGANv030/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..118e930c12bffd9e6da1df03180f5c9a8dcaabc3
--- /dev/null
+++ b/RealESRGANv030/README.md
@@ -0,0 +1,272 @@
+
+
+
+
+##
+
+
+
+👀[**Demos**](#-demos-videos) **|** 🚩[**Updates**](#-updates) **|** ⚡[**Usage**](#-quick-inference) **|** 🏰[**Model Zoo**](docs/model_zoo.md) **|** 🔧[Install](#-dependencies-and-installation) **|** 💻[Train](docs/Training.md) **|** ❓[FAQ](docs/FAQ.md) **|** 🎨[Contribution](docs/CONTRIBUTING.md)
+
+[](https://github.com/xinntao/Real-ESRGAN/releases)
+[](https://pypi.org/project/realesrgan/)
+[](https://github.com/xinntao/Real-ESRGAN/issues)
+[](https://github.com/xinntao/Real-ESRGAN/issues)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)
+
+
+
+🔥 **AnimeVideo-v3 model (动漫视频小模型)**. Please see [[*anime video models*](docs/anime_video_model.md)] and [[*comparisons*](docs/anime_comparisons.md)]
+🔥 **RealESRGAN_x4plus_anime_6B** for anime images **(动漫插图模型)**. Please see [[*anime_model*](docs/anime_model.md)]
+
+
+1. :boom: **Update** online Replicate demo: [](https://replicate.com/xinntao/realesrgan)
+1. Online Colab demo for Real-ESRGAN: [](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) **|** Online Colab demo for for Real-ESRGAN (**anime videos**): [](https://colab.research.google.com/drive/1yNl9ORUxxlL4N0keJa2SEPB61imPQd1B?usp=sharing)
+1. Portable [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip) **executable files for Intel/AMD/Nvidia GPU**. You can find more information [here](#portable-executable-files-ncnn). The ncnn implementation is in [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan)
+
+
+Real-ESRGAN aims at developing **Practical Algorithms for General Image/Video Restoration**.
+We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data.
+
+🌌 Thanks for your valuable feedbacks/suggestions. All the feedbacks are updated in [feedback.md](docs/feedback.md).
+
+---
+
+If Real-ESRGAN is helpful, please help to ⭐ this repo or recommend it to your friends 😊
+Other recommended projects:
+▶️ [GFPGAN](https://github.com/TencentARC/GFPGAN): A practical algorithm for real-world face restoration
+▶️ [BasicSR](https://github.com/xinntao/BasicSR): An open-source image and video restoration toolbox
+▶️ [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-relation functions.
+▶️ [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for view and comparison
+▶️ [HandyFigure](https://github.com/xinntao/HandyFigure): Open source of paper figures
+
+---
+
+### 📖 Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
+
+> [[Paper](https://arxiv.org/abs/2107.10833)] [[YouTube Video](https://www.youtube.com/watch?v=fxHWoDSSvSc)] [[B站讲解](https://www.bilibili.com/video/BV1H34y1m7sS/)] [[Poster](https://xinntao.github.io/projects/RealESRGAN_src/RealESRGAN_poster.pdf)] [[PPT slides](https://docs.google.com/presentation/d/1QtW6Iy8rm8rGLsJ0Ldti6kP-7Qyzy6XL/edit?usp=sharing&ouid=109799856763657548160&rtpof=true&sd=true)]
+> [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)
+> [Tencent ARC Lab](https://arc.tencent.com/en/ai-demos/imgRestore); Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
+
+
+
+
+
+---
+
+
+## 🚩 Updates
+
+- ✅ Add the **realesr-general-x4v3** model - a tiny small model for general scenes. It also supports the **--dn** option to balance the noise (avoiding over-smooth results). **--dn** is short for denoising strength.
+- ✅ Update the **RealESRGAN AnimeVideo-v3** model. Please see [anime video models](docs/anime_video_model.md) and [comparisons](docs/anime_comparisons.md) for more details.
+- ✅ Add small models for anime videos. More details are in [anime video models](docs/anime_video_model.md).
+- ✅ Add the ncnn implementation [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).
+- ✅ Add [*RealESRGAN_x4plus_anime_6B.pth*](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth), which is optimized for **anime** images with much smaller model size. More details and comparisons with [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan) are in [**anime_model.md**](docs/anime_model.md)
+- ✅ Support finetuning on your own data or paired data (*i.e.*, finetuning ESRGAN). See [here](docs/Training.md#Finetune-Real-ESRGAN-on-your-own-dataset)
+- ✅ Integrate [GFPGAN](https://github.com/TencentARC/GFPGAN) to support **face enhancement**.
+- ✅ Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN). Thanks [@AK391](https://github.com/AK391)
+- ✅ Support arbitrary scale with `--outscale` (It actually further resizes outputs with `LANCZOS4`). Add *RealESRGAN_x2plus.pth* model.
+- ✅ [The inference code](inference_realesrgan.py) supports: 1) **tile** options; 2) images with **alpha channel**; 3) **gray** images; 4) **16-bit** images.
+- ✅ The training codes have been released. A detailed guide can be found in [Training.md](docs/Training.md).
+
+---
+
+
+## 👀 Demos Videos
+
+#### Bilibili
+
+- [大闹天宫片段](https://www.bilibili.com/video/BV1ja41117zb)
+- [Anime dance cut 动漫魔性舞蹈](https://www.bilibili.com/video/BV1wY4y1L7hT/)
+- [海贼王片段](https://www.bilibili.com/video/BV1i3411L7Gy/)
+
+#### YouTube
+
+## 🔧 Dependencies and Installation
+
+- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
+- [PyTorch >= 1.7](https://pytorch.org/)
+
+### Installation
+
+1. Clone repo
+
+ ```bash
+ git clone https://github.com/xinntao/Real-ESRGAN.git
+ cd Real-ESRGAN
+ ```
+
+1. Install dependent packages
+
+ ```bash
+ # Install basicsr - https://github.com/xinntao/BasicSR
+ # We use BasicSR for both training and inference
+ pip install basicsr
+ # facexlib and gfpgan are for face enhancement
+ pip install facexlib
+ pip install gfpgan
+ pip install -r requirements.txt
+ python setup.py develop
+ ```
+
+---
+
+## ⚡ Quick Inference
+
+There are usually three ways to inference Real-ESRGAN.
+
+1. [Online inference](#online-inference)
+1. [Portable executable files (NCNN)](#portable-executable-files-ncnn)
+1. [Python script](#python-script)
+
+### Online inference
+
+1. You can try in our website: [ARC Demo](https://arc.tencent.com/en/ai-demos/imgRestore) (now only support RealESRGAN_x4plus_anime_6B)
+1. [Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) for Real-ESRGAN **|** [Colab Demo](https://colab.research.google.com/drive/1yNl9ORUxxlL4N0keJa2SEPB61imPQd1B?usp=sharing) for Real-ESRGAN (**anime videos**).
+
+### Portable executable files (NCNN)
+
+You can download [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip) **executable files for Intel/AMD/Nvidia GPU**.
+
+This executable file is **portable** and includes all the binaries and models required. No CUDA or PyTorch environment is needed.
+
+You can simply run the following command (the Windows example, more information is in the README.md of each executable files):
+
+```bash
+./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n model_name
+```
+
+We have provided five models:
+
+1. realesrgan-x4plus (default)
+2. realesrnet-x4plus
+3. realesrgan-x4plus-anime (optimized for anime images, small model size)
+4. realesr-animevideov3 (animation video)
+
+You can use the `-n` argument for other models, for example, `./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrnet-x4plus`
+
+#### Usage of portable executable files
+
+1. Please refer to [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan#computer-usages) for more details.
+1. Note that it does not support all the functions (such as `outscale`) as the python script `inference_realesrgan.py`.
+
+```console
+Usage: realesrgan-ncnn-vulkan.exe -i infile -o outfile [options]...
+
+ -h show this help
+ -i input-path input image path (jpg/png/webp) or directory
+ -o output-path output image path (jpg/png/webp) or directory
+ -s scale upscale ratio (can be 2, 3, 4. default=4)
+ -t tile-size tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu
+ -m model-path folder path to the pre-trained models. default=models
+ -n model-name model name (default=realesr-animevideov3, can be realesr-animevideov3 | realesrgan-x4plus | realesrgan-x4plus-anime | realesrnet-x4plus)
+ -g gpu-id gpu device to use (default=auto) can be 0,1,2 for multi-gpu
+ -j load:proc:save thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu
+ -x enable tta mode"
+ -f format output image format (jpg/png/webp, default=ext/png)
+ -v verbose output
+```
+
+Note that it may introduce block inconsistency (and also generate slightly different results from the PyTorch implementation), because this executable file first crops the input image into several tiles, and then processes them separately, finally stitches together.
+
+### Python script
+
+#### Usage of python script
+
+1. You can use X4 model for **arbitrary output size** with the argument `outscale`. The program will further perform cheap resize operation after the Real-ESRGAN output.
+
+```console
+Usage: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile -o outfile [options]...
+
+A common command: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile --outscale 3.5 --face_enhance
+
+ -h show this help
+ -i --input Input image or folder. Default: inputs
+ -o --output Output folder. Default: results
+ -n --model_name Model name. Default: RealESRGAN_x4plus
+ -s, --outscale The final upsampling scale of the image. Default: 4
+ --suffix Suffix of the restored image. Default: out
+ -t, --tile Tile size, 0 for no tile during testing. Default: 0
+ --face_enhance Whether to use GFPGAN to enhance face. Default: False
+ --fp32 Use fp32 precision during inference. Default: fp16 (half precision).
+ --ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
+```
+
+#### Inference general images
+
+Download pre-trained models: [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
+
+```bash
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights
+```
+
+Inference!
+
+```bash
+python inference_realesrgan.py -n RealESRGAN_x4plus -i inputs --face_enhance
+```
+
+Results are in the `results` folder
+
+#### Inference anime images
+
+
+
+
+
+Pre-trained models: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)
+ More details and comparisons with [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan) are in [**anime_model.md**](docs/anime_model.md)
+
+```bash
+# download model
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
+# inference
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
+```
+
+Results are in the `results` folder
+
+---
+
+## BibTeX
+
+ @InProceedings{wang2021realesrgan,
+ author = {Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
+ title = {Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
+ booktitle = {International Conference on Computer Vision Workshops (ICCVW)},
+ date = {2021}
+ }
+
+## 📧 Contact
+
+If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
+
+
+## 🧩 Projects that use Real-ESRGAN
+
+If you develop/use Real-ESRGAN in your projects, welcome to let me know.
+
+- NCNN-Android: [RealSR-NCNN-Android](https://github.com/tumuyan/RealSR-NCNN-Android) by [tumuyan](https://github.com/tumuyan)
+- VapourSynth: [vs-realesrgan](https://github.com/HolyWu/vs-realesrgan) by [HolyWu](https://github.com/HolyWu)
+- NCNN: [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan)
+
+ **GUI**
+
+- [Waifu2x-Extension-GUI](https://github.com/AaronFeng753/Waifu2x-Extension-GUI) by [AaronFeng753](https://github.com/AaronFeng753)
+- [Squirrel-RIFE](https://github.com/Justin62628/Squirrel-RIFE) by [Justin62628](https://github.com/Justin62628)
+- [Real-GUI](https://github.com/scifx/Real-GUI) by [scifx](https://github.com/scifx)
+- [Real-ESRGAN_GUI](https://github.com/net2cn/Real-ESRGAN_GUI) by [net2cn](https://github.com/net2cn)
+- [Real-ESRGAN-EGUI](https://github.com/WGzeyu/Real-ESRGAN-EGUI) by [WGzeyu](https://github.com/WGzeyu)
+- [anime_upscaler](https://github.com/shangar21/anime_upscaler) by [shangar21](https://github.com/shangar21)
+- [Upscayl](https://github.com/upscayl/upscayl) by [Nayam Amarshe](https://github.com/NayamAmarshe) and [TGS963](https://github.com/TGS963)
+
+## 🤗 Acknowledgement
+
+Thanks for all the contributors.
+
+- [AK391](https://github.com/AK391): Integrate RealESRGAN to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Real-ESRGAN).
+- [Asiimoviet](https://github.com/Asiimoviet): Translate the README.md to Chinese (中文).
+- [2ji3150](https://github.com/2ji3150): Thanks for the [detailed and valuable feedbacks/suggestions](https://github.com/xinntao/Real-ESRGAN/issues/131).
+- [Jared-02](https://github.com/Jared-02): Translate the Training.md to Chinese (中文).
diff --git a/RealESRGANv030/README_CN.md b/RealESRGANv030/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..fda1217bec600c5dcea72624c13533be6b71453e
--- /dev/null
+++ b/RealESRGANv030/README_CN.md
@@ -0,0 +1,276 @@
+
+
+
+
+##
+
+[](https://github.com/xinntao/Real-ESRGAN/releases)
+[](https://pypi.org/project/realesrgan/)
+[](https://github.com/xinntao/Real-ESRGAN/issues)
+[](https://github.com/xinntao/Real-ESRGAN/issues)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/pylint.yml)
+[](https://github.com/xinntao/Real-ESRGAN/blob/master/.github/workflows/publish-pip.yml)
+
+:fire: 更新动漫视频的小模型 **RealESRGAN AnimeVideo-v3**. 更多信息在 [[动漫视频模型介绍](docs/anime_video_model.md)] 和 [[比较](docs/anime_comparisons_CN.md)] 中.
+
+1. Real-ESRGAN的[Colab Demo](https://colab.research.google.com/drive/1k2Zod6kSHEvraybHl50Lys0LerhyTMCo?usp=sharing) | Real-ESRGAN**动漫视频** 的[Colab Demo](https://colab.research.google.com/drive/1yNl9ORUxxlL4N0keJa2SEPB61imPQd1B?usp=sharing)
+2. **支持Intel/AMD/Nvidia显卡**的绿色版exe文件: [Windows版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [macOS版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip),详情请移步[这里](#便携版(绿色版)可执行文件)。NCNN的实现在 [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan)。
+
+Real-ESRGAN 的目标是开发出**实用的图像/视频修复算法**。
+我们在 ESRGAN 的基础上使用纯合成的数据来进行训练,以使其能被应用于实际的图片修复的场景(顾名思义:Real-ESRGAN)。
+
+:art: Real-ESRGAN 需要,也很欢迎你的贡献,如新功能、模型、bug修复、建议、维护等等。详情可以查看[CONTRIBUTING.md](docs/CONTRIBUTING.md),所有的贡献者都会被列在[此处](README_CN.md#hugs-感谢)。
+
+:milky_way: 感谢大家提供了很好的反馈。这些反馈会逐步更新在 [这个文档](docs/feedback.md)。
+
+:question: 常见的问题可以在[FAQ.md](docs/FAQ.md)中找到答案。(好吧,现在还是空白的=-=||)
+
+---
+
+如果 Real-ESRGAN 对你有帮助,可以给本项目一个 Star :star: ,或者推荐给你的朋友们,谢谢!:blush:
+其他推荐的项目:
+:arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN): 实用的人脸复原算法
+:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): 开源的图像和视频工具箱
+:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): 提供与人脸相关的工具箱
+:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): 基于PyQt5的图片查看器,方便查看以及比较
+
+---
+
+
+
+🚩更新
+
+- ✅ 更新动漫视频的小模型 **RealESRGAN AnimeVideo-v3**. 更多信息在 [anime video models](docs/anime_video_model.md) 和 [comparisons](docs/anime_comparisons.md)中.
+- ✅ 添加了针对动漫视频的小模型, 更多信息在 [anime video models](docs/anime_video_model.md) 中.
+- ✅ 添加了ncnn 实现:[Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan).
+- ✅ 添加了 [*RealESRGAN_x4plus_anime_6B.pth*](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth),对二次元图片进行了优化,并减少了model的大小。详情 以及 与[waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan)的对比请查看[**anime_model.md**](docs/anime_model.md)
+- ✅支持用户在自己的数据上进行微调 (finetune):[详情](docs/Training.md#Finetune-Real-ESRGAN-on-your-own-dataset)
+- ✅ 支持使用[GFPGAN](https://github.com/TencentARC/GFPGAN)**增强人脸**
+- ✅ 通过[Gradio](https://github.com/gradio-app/gradio)添加到了[Huggingface Spaces](https://huggingface.co/spaces)(一个机器学习应用的在线平台):[Gradio在线版](https://huggingface.co/spaces/akhaliq/Real-ESRGAN)。感谢[@AK391](https://github.com/AK391)
+- ✅ 支持任意比例的缩放:`--outscale`(实际上使用`LANCZOS4`来更进一步调整输出图像的尺寸)。添加了*RealESRGAN_x2plus.pth*模型
+- ✅ [推断脚本](inference_realesrgan.py)支持: 1) 分块处理**tile**; 2) 带**alpha通道**的图像; 3) **灰色**图像; 4) **16-bit**图像.
+- ✅ 训练代码已经发布,具体做法可查看:[Training.md](docs/Training.md)。
+
+
+
+
+
+🧩使用Real-ESRGAN的项目
+
+ 👋 如果你开发/使用/集成了Real-ESRGAN, 欢迎联系我添加
+
+- NCNN-Android: [RealSR-NCNN-Android](https://github.com/tumuyan/RealSR-NCNN-Android) by [tumuyan](https://github.com/tumuyan)
+- VapourSynth: [vs-realesrgan](https://github.com/HolyWu/vs-realesrgan) by [HolyWu](https://github.com/HolyWu)
+- NCNN: [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan)
+
+ **易用的图形界面**
+
+- [Waifu2x-Extension-GUI](https://github.com/AaronFeng753/Waifu2x-Extension-GUI) by [AaronFeng753](https://github.com/AaronFeng753)
+- [Squirrel-RIFE](https://github.com/Justin62628/Squirrel-RIFE) by [Justin62628](https://github.com/Justin62628)
+- [Real-GUI](https://github.com/scifx/Real-GUI) by [scifx](https://github.com/scifx)
+- [Real-ESRGAN_GUI](https://github.com/net2cn/Real-ESRGAN_GUI) by [net2cn](https://github.com/net2cn)
+- [Real-ESRGAN-EGUI](https://github.com/WGzeyu/Real-ESRGAN-EGUI) by [WGzeyu](https://github.com/WGzeyu)
+- [anime_upscaler](https://github.com/shangar21/anime_upscaler) by [shangar21](https://github.com/shangar21)
+- [RealESRGAN-GUI](https://github.com/Baiyuetribe/paper2gui/blob/main/Video%20Super%20Resolution/RealESRGAN-GUI.md) by [Baiyuetribe](https://github.com/Baiyuetribe)
+
+
+
+
+👀Demo视频(B站)
+
+- [大闹天宫片段](https://www.bilibili.com/video/BV1ja41117zb)
+
+
+
+### :book: Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
+
+> [[论文](https://arxiv.org/abs/2107.10833)] [项目主页] [[YouTube 视频](https://www.youtube.com/watch?v=fxHWoDSSvSc)] [[B站视频](https://www.bilibili.com/video/BV1H34y1m7sS/)] [[Poster](https://xinntao.github.io/projects/RealESRGAN_src/RealESRGAN_poster.pdf)] [[PPT](https://docs.google.com/presentation/d/1QtW6Iy8rm8rGLsJ0Ldti6kP-7Qyzy6XL/edit?usp=sharing&ouid=109799856763657548160&rtpof=true&sd=true)]
+> [Xintao Wang](https://xinntao.github.io/), Liangbin Xie, [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)
+> Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
+
+
+
+
+
+---
+
+我们提供了一套训练好的模型(*RealESRGAN_x4plus.pth*),可以进行4倍的超分辨率。
+**现在的 Real-ESRGAN 还是有几率失败的,因为现实生活的降质过程比较复杂。**
+而且,本项目对**人脸以及文字之类**的效果还不是太好,但是我们会持续进行优化的。
+
+Real-ESRGAN 将会被长期支持,我会在空闲的时间中持续维护更新。
+
+这些是未来计划的几个新功能:
+
+- [ ] 优化人脸
+- [ ] 优化文字
+- [x] 优化动画图像
+- [ ] 支持更多的超分辨率比例
+- [ ] 可调节的复原
+
+如果你有好主意或需求,欢迎在 issue 或 discussion 中提出。
+如果你有一些 Real-ESRGAN 中有问题的照片,你也可以在 issue 或者 discussion 中发出来。我会留意(但是不一定能解决:stuck_out_tongue:)。如果有必要的话,我还会专门开一页来记录那些有待解决的图像。
+
+---
+
+### 便携版(绿色版)可执行文件
+
+你可以下载**支持Intel/AMD/Nvidia显卡**的绿色版exe文件: [Windows版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [macOS版](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip)。
+
+绿色版指的是这些exe你可以直接运行(放U盘里拷走都没问题),因为里面已经有所需的文件和模型了。它不需要 CUDA 或者 PyTorch运行环境。
+
+你可以通过下面这个命令来运行(Windows版本的例子,更多信息请查看对应版本的README.md):
+
+```bash
+./realesrgan-ncnn-vulkan.exe -i 输入图像.jpg -o 输出图像.png -n 模型名字
+```
+
+我们提供了五种模型:
+
+1. realesrgan-x4plus(默认)
+2. reaesrnet-x4plus
+3. realesrgan-x4plus-anime(针对动漫插画图像优化,有更小的体积)
+4. realesr-animevideov3 (针对动漫视频)
+
+你可以通过`-n`参数来使用其他模型,例如`./realesrgan-ncnn-vulkan.exe -i 二次元图片.jpg -o 二刺螈图片.png -n realesrgan-x4plus-anime`
+
+### 可执行文件的用法
+
+1. 更多细节可以参考 [Real-ESRGAN-ncnn-vulkan](https://github.com/xinntao/Real-ESRGAN-ncnn-vulkan#computer-usages).
+2. 注意:可执行文件并没有支持 python 脚本 `inference_realesrgan.py` 中所有的功能,比如 `outscale` 选项) .
+
+```console
+Usage: realesrgan-ncnn-vulkan.exe -i infile -o outfile [options]...
+
+ -h show this help
+ -i input-path input image path (jpg/png/webp) or directory
+ -o output-path output image path (jpg/png/webp) or directory
+ -s scale upscale ratio (can be 2, 3, 4. default=4)
+ -t tile-size tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu
+ -m model-path folder path to the pre-trained models. default=models
+ -n model-name model name (default=realesr-animevideov3, can be realesr-animevideov3 | realesrgan-x4plus | realesrgan-x4plus-anime | realesrnet-x4plus)
+ -g gpu-id gpu device to use (default=auto) can be 0,1,2 for multi-gpu
+ -j load:proc:save thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu
+ -x enable tta mode"
+ -f format output image format (jpg/png/webp, default=ext/png)
+ -v verbose output
+```
+
+由于这些exe文件会把图像分成几个板块,然后来分别进行处理,再合成导出,输出的图像可能会有一点割裂感(而且可能跟PyTorch的输出不太一样)
+
+---
+
+## :wrench: 依赖以及安装
+
+- Python >= 3.7 (推荐使用[Anaconda](https://www.anaconda.com/download/#linux)或[Miniconda](https://docs.conda.io/en/latest/miniconda.html))
+- [PyTorch >= 1.7](https://pytorch.org/)
+
+#### 安装
+
+1. 把项目克隆到本地
+
+ ```bash
+ git clone https://github.com/xinntao/Real-ESRGAN.git
+ cd Real-ESRGAN
+ ```
+
+2. 安装各种依赖
+
+ ```bash
+ # 安装 basicsr - https://github.com/xinntao/BasicSR
+ # 我们使用BasicSR来训练以及推断
+ pip install basicsr
+ # facexlib和gfpgan是用来增强人脸的
+ pip install facexlib
+ pip install gfpgan
+ pip install -r requirements.txt
+ python setup.py develop
+ ```
+
+## :zap: 快速上手
+
+### 普通图片
+
+下载我们训练好的模型: [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth)
+
+```bash
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights
+```
+
+推断!
+
+```bash
+python inference_realesrgan.py -n RealESRGAN_x4plus -i inputs --face_enhance
+```
+
+结果在`results`文件夹
+
+### 动画图片
+
+
+
+
+
+训练好的模型: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)
+有关[waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan)的更多信息和对比在[**anime_model.md**](docs/anime_model.md)中。
+
+```bash
+# 下载模型
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
+# 推断
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
+```
+
+结果在`results`文件夹
+
+### Python 脚本的用法
+
+1. 虽然你使用了 X4 模型,但是你可以 **输出任意尺寸比例的图片**,只要实用了 `outscale` 参数. 程序会进一步对模型的输出图像进行缩放。
+
+```console
+Usage: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile -o outfile [options]...
+
+A common command: python inference_realesrgan.py -n RealESRGAN_x4plus -i infile --outscale 3.5 --face_enhance
+
+ -h show this help
+ -i --input Input image or folder. Default: inputs
+ -o --output Output folder. Default: results
+ -n --model_name Model name. Default: RealESRGAN_x4plus
+ -s, --outscale The final upsampling scale of the image. Default: 4
+ --suffix Suffix of the restored image. Default: out
+ -t, --tile Tile size, 0 for no tile during testing. Default: 0
+ --face_enhance Whether to use GFPGAN to enhance face. Default: False
+ --fp32 Whether to use half precision during inference. Default: False
+ --ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
+```
+
+## :european_castle: 模型库
+
+请参见 [docs/model_zoo.md](docs/model_zoo.md)
+
+## :computer: 训练,在你的数据上微调(Fine-tune)
+
+这里有一份详细的指南:[Training.md](docs/Training.md).
+
+## BibTeX 引用
+
+ @Article{wang2021realesrgan,
+ title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
+ author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
+ journal={arXiv:2107.10833},
+ year={2021}
+ }
+
+## :e-mail: 联系我们
+
+如果你有任何问题,请通过 `xintao.wang@outlook.com` 或 `xintaowang@tencent.com` 联系我们。
+
+## :hugs: 感谢
+
+感谢所有的贡献者大大们~
+
+- [AK391](https://github.com/AK391): 通过[Gradio](https://github.com/gradio-app/gradio)添加到了[Huggingface Spaces](https://huggingface.co/spaces)(一个机器学习应用的在线平台):[Gradio在线版](https://huggingface.co/spaces/akhaliq/Real-ESRGAN)。
+- [Asiimoviet](https://github.com/Asiimoviet): 把 README.md 文档 翻译成了中文。
+- [2ji3150](https://github.com/2ji3150): 感谢详尽并且富有价值的[反馈、建议](https://github.com/xinntao/Real-ESRGAN/issues/131).
+- [Jared-02](https://github.com/Jared-02): 把 Training.md 文档 翻译成了中文。
diff --git a/RealESRGANv030/VERSION b/RealESRGANv030/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..0d91a54c7d439e84e3dd17d3594f1b2b6737f430
--- /dev/null
+++ b/RealESRGANv030/VERSION
@@ -0,0 +1 @@
+0.3.0
diff --git a/RealESRGANv030/__init__.py b/RealESRGANv030/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/RealESRGANv030/assets/realesrgan_logo.png b/RealESRGANv030/assets/realesrgan_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..88cd1ad6170794c2becb95006edffa0655d9372a
Binary files /dev/null and b/RealESRGANv030/assets/realesrgan_logo.png differ
diff --git a/RealESRGANv030/assets/realesrgan_logo_ai.png b/RealESRGANv030/assets/realesrgan_logo_ai.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0f595cf2535de7e69393384d8d056300f1cdddc
Binary files /dev/null and b/RealESRGANv030/assets/realesrgan_logo_ai.png differ
diff --git a/RealESRGANv030/assets/realesrgan_logo_av.png b/RealESRGANv030/assets/realesrgan_logo_av.png
new file mode 100644
index 0000000000000000000000000000000000000000..501ac8e81292d9369122a69ec2dd56a3ae8beca6
Binary files /dev/null and b/RealESRGANv030/assets/realesrgan_logo_av.png differ
diff --git a/RealESRGANv030/assets/realesrgan_logo_gi.png b/RealESRGANv030/assets/realesrgan_logo_gi.png
new file mode 100644
index 0000000000000000000000000000000000000000..cdb0a1a74e0b54a1c684141324c6635acf2f60f8
Binary files /dev/null and b/RealESRGANv030/assets/realesrgan_logo_gi.png differ
diff --git a/RealESRGANv030/assets/realesrgan_logo_gv.png b/RealESRGANv030/assets/realesrgan_logo_gv.png
new file mode 100644
index 0000000000000000000000000000000000000000..21dfba05f3855f1d9740e6d2cbe2a8ac736f4508
Binary files /dev/null and b/RealESRGANv030/assets/realesrgan_logo_gv.png differ
diff --git a/RealESRGANv030/assets/teaser-text.png b/RealESRGANv030/assets/teaser-text.png
new file mode 100644
index 0000000000000000000000000000000000000000..af9b424e390bf454838d962f049db9bb5ef1064d
Binary files /dev/null and b/RealESRGANv030/assets/teaser-text.png differ
diff --git a/RealESRGANv030/assets/teaser.jpg b/RealESRGANv030/assets/teaser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dc9b7ccdf78e3c816b0b6ca567433b53253b2e1e
Binary files /dev/null and b/RealESRGANv030/assets/teaser.jpg differ
diff --git a/RealESRGANv030/cog.yaml b/RealESRGANv030/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..daa6983934b6e186ecd0cf1d4e038acdb9910cbc
--- /dev/null
+++ b/RealESRGANv030/cog.yaml
@@ -0,0 +1,22 @@
+# This file is used for constructing replicate env
+image: "r8.im/tencentarc/realesrgan"
+
+build:
+ gpu: true
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "torch==1.7.1"
+ - "torchvision==0.8.2"
+ - "numpy==1.21.1"
+ - "lmdb==1.2.1"
+ - "opencv-python==4.5.3.56"
+ - "PyYAML==5.4.1"
+ - "tqdm==4.62.2"
+ - "yapf==0.31.0"
+ - "basicsr==1.4.2"
+ - "facexlib==0.2.5"
+
+predict: "cog_predict.py:Predictor"
diff --git a/RealESRGANv030/cog_predict.py b/RealESRGANv030/cog_predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0f89dfda8e3ff14afd7b3b8544f04d86e96562
--- /dev/null
+++ b/RealESRGANv030/cog_predict.py
@@ -0,0 +1,148 @@
+# flake8: noqa
+# This file is used for deploying replicate models
+# running: cog predict -i img=@inputs/00017_gray.png -i version='General - v3' -i scale=2 -i face_enhance=True -i tile=0
+# push: cog push r8.im/xinntao/realesrgan
+
+import os
+
+os.system('pip install gfpgan')
+os.system('python setup.py develop')
+
+import cv2
+import shutil
+import tempfile
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.archs.srvgg_arch import SRVGGNetCompact
+
+from realesrgan.utils import RealESRGANer
+
+try:
+ from cog import BasePredictor, Input, Path
+ from gfpgan import GFPGANer
+except Exception:
+ print('please install cog and realesrgan package')
+
+
+class Predictor(BasePredictor):
+
+ def setup(self):
+ os.makedirs('output', exist_ok=True)
+ # download weights
+ if not os.path.exists('weights/realesr-general-x4v3.pth'):
+ os.system(
+ 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
+ )
+ if not os.path.exists('weights/GFPGANv1.4.pth'):
+ os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
+ if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
+ os.system(
+ 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
+ )
+ if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
+ os.system(
+ 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
+ )
+ if not os.path.exists('weights/realesr-animevideov3.pth'):
+ os.system(
+ 'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
+ )
+
+ def choose_model(self, scale, version, tile=0):
+ half = True if torch.cuda.is_available() else False
+ if version == 'General - RealESRGANplus':
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ model_path = 'weights/RealESRGAN_x4plus.pth'
+ self.upsampler = RealESRGANer(
+ scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
+ elif version == 'General - v3':
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ model_path = 'weights/realesr-general-x4v3.pth'
+ self.upsampler = RealESRGANer(
+ scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
+ elif version == 'Anime - anime6B':
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
+ self.upsampler = RealESRGANer(
+ scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
+ elif version == 'AnimeVideo - v3':
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ model_path = 'weights/realesr-animevideov3.pth'
+ self.upsampler = RealESRGANer(
+ scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
+
+ self.face_enhancer = GFPGANer(
+ model_path='weights/GFPGANv1.4.pth',
+ upscale=scale,
+ arch='clean',
+ channel_multiplier=2,
+ bg_upsampler=self.upsampler)
+
+ def predict(
+ self,
+ img: Path = Input(description='Input'),
+ version: str = Input(
+ description='RealESRGAN version. Please see [Readme] below for more descriptions',
+ choices=['General - RealESRGANplus', 'General - v3', 'Anime - anime6B', 'AnimeVideo - v3'],
+ default='General - v3'),
+ scale: float = Input(description='Rescaling factor', default=2),
+ face_enhance: bool = Input(
+ description='Enhance faces with GFPGAN. Note that it does not work for anime images/vidoes', default=False),
+ tile: int = Input(
+ description=
+ 'Tile size. Default is 0, that is no tile. When encountering the out-of-GPU-memory issue, please specify it, e.g., 400 or 200',
+ default=0)
+ ) -> Path:
+ if tile <= 100 or tile is None:
+ tile = 0
+ print(f'img: {img}. version: {version}. scale: {scale}. face_enhance: {face_enhance}. tile: {tile}.')
+ try:
+ extension = os.path.splitext(os.path.basename(str(img)))[1]
+ img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ elif len(img.shape) == 2:
+ img_mode = None
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ else:
+ img_mode = None
+
+ h, w = img.shape[0:2]
+ if h < 300:
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+ self.choose_model(scale, version, tile)
+
+ try:
+ if face_enhance:
+ _, _, output = self.face_enhancer.enhance(
+ img, has_aligned=False, only_center_face=False, paste_back=True)
+ else:
+ output, _ = self.upsampler.enhance(img, outscale=scale)
+ except RuntimeError as error:
+ print('Error', error)
+ print('If you encounter CUDA out of memory, try to set "tile" to a smaller size, e.g., 400.')
+
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ extension = 'png'
+ # save_path = f'output/out.{extension}'
+ # cv2.imwrite(save_path, output)
+ out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
+ cv2.imwrite(str(out_path), output)
+ except Exception as error:
+ print('global exception: ', error)
+ finally:
+ clean_folder('output')
+ return out_path
+
+
+def clean_folder(folder):
+ for filename in os.listdir(folder):
+ file_path = os.path.join(folder, filename)
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path)
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ except Exception as e:
+ print(f'Failed to delete {file_path}. Reason: {e}')
diff --git a/RealESRGANv030/docs/CONTRIBUTING.md b/RealESRGANv030/docs/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..75990c2ce7545b72fb6ebad8295ca4895f437205
--- /dev/null
+++ b/RealESRGANv030/docs/CONTRIBUTING.md
@@ -0,0 +1,44 @@
+# Contributing to Real-ESRGAN
+
+:art: Real-ESRGAN needs your contributions. Any contributions are welcome, such as new features/models/typo fixes/suggestions/maintenance, *etc*. See [CONTRIBUTING.md](docs/CONTRIBUTING.md). All contributors are list [here](README.md#hugs-acknowledgement).
+
+We like open-source and want to develop practical algorithms for general image restoration. However, individual strength is limited. So, any kinds of contributions are welcome, such as:
+
+- New features
+- New models (your fine-tuned models)
+- Bug fixes
+- Typo fixes
+- Suggestions
+- Maintenance
+- Documents
+- *etc*
+
+## Workflow
+
+1. Fork and pull the latest Real-ESRGAN repository
+1. Checkout a new branch (do not use master branch for PRs)
+1. Commit your changes
+1. Create a PR
+
+**Note**:
+
+1. Please check the code style and linting
+ 1. The style configuration is specified in [setup.cfg](setup.cfg)
+ 1. If you use VSCode, the settings are configured in [.vscode/settings.json](.vscode/settings.json)
+1. Strongly recommend using `pre-commit hook`. It will check your code style and linting before your commit.
+ 1. In the root path of project folder, run `pre-commit install`
+ 1. The pre-commit configuration is listed in [.pre-commit-config.yaml](.pre-commit-config.yaml)
+1. Better to [open a discussion](https://github.com/xinntao/Real-ESRGAN/discussions) before large changes.
+ 1. Welcome to discuss :sunglasses:. I will try my best to join the discussion.
+
+## TODO List
+
+:zero: The most straightforward way of improving model performance is to fine-tune on some specific datasets.
+
+Here are some TODOs:
+
+- [ ] optimize for human faces
+- [ ] optimize for texts
+- [ ] support controllable restoration strength
+
+:one: There are also [several issues](https://github.com/xinntao/Real-ESRGAN/issues) that require helpers to improve. If you can help, please let me know :smile:
diff --git a/RealESRGANv030/docs/FAQ.md b/RealESRGANv030/docs/FAQ.md
new file mode 100644
index 0000000000000000000000000000000000000000..843f4dd847487066a1c7c105c7292e2de0bd5f1a
--- /dev/null
+++ b/RealESRGANv030/docs/FAQ.md
@@ -0,0 +1,10 @@
+# FAQ
+
+1. **Q: How to select models?**
+A: Please refer to [docs/model_zoo.md](docs/model_zoo.md)
+
+1. **Q: Can `face_enhance` be used for anime images/animation videos?**
+A: No, it can only be used for real faces. It is recommended not to use this option for anime images/animation videos to save GPU memory.
+
+1. **Q: Error "slow_conv2d_cpu" not implemented for 'Half'**
+A: In order to save GPU memory consumption and speed up inference, Real-ESRGAN uses half precision (fp16) during inference by default. However, some operators for half inference are not implemented in CPU mode. You need to add **`--fp32` option** for the commands. For example, `python inference_realesrgan.py -n RealESRGAN_x4plus.pth -i inputs --fp32`.
diff --git a/RealESRGANv030/docs/Training.md b/RealESRGANv030/docs/Training.md
new file mode 100644
index 0000000000000000000000000000000000000000..77da5ea5763f7a6ab291ebc28afb13be37df3f50
--- /dev/null
+++ b/RealESRGANv030/docs/Training.md
@@ -0,0 +1,271 @@
+# :computer: How to Train/Finetune Real-ESRGAN
+
+- [Train Real-ESRGAN](#train-real-esrgan)
+ - [Overview](#overview)
+ - [Dataset Preparation](#dataset-preparation)
+ - [Train Real-ESRNet](#Train-Real-ESRNet)
+ - [Train Real-ESRGAN](#Train-Real-ESRGAN)
+- [Finetune Real-ESRGAN on your own dataset](#Finetune-Real-ESRGAN-on-your-own-dataset)
+ - [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
+ - [Use paired training data](#use-your-own-paired-data)
+
+[English](Training.md) **|** [简体中文](Training_CN.md)
+
+## Train Real-ESRGAN
+
+### Overview
+
+The training has been divided into two stages. These two stages have the same data synthesis process and training pipeline, except for the loss functions. Specifically,
+
+1. We first train Real-ESRNet with L1 loss from the pre-trained model ESRGAN.
+1. We then use the trained Real-ESRNet model as an initialization of the generator, and train the Real-ESRGAN with a combination of L1 loss, perceptual loss and GAN loss.
+
+### Dataset Preparation
+
+We use DF2K (DIV2K and Flickr2K) + OST datasets for our training. Only HR images are required.
+You can download from :
+
+1. DIV2K: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
+2. Flickr2K: https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
+3. OST: https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip
+
+Here are steps for data preparation.
+
+#### Step 1: [Optional] Generate multi-scale images
+
+For the DF2K dataset, we use a multi-scale strategy, *i.e.*, we downsample HR images to obtain several Ground-Truth images with different scales.
+You can use the [scripts/generate_multiscale_DF2K.py](scripts/generate_multiscale_DF2K.py) script to generate multi-scale images.
+Note that this step can be omitted if you just want to have a fast try.
+
+```bash
+python scripts/generate_multiscale_DF2K.py --input datasets/DF2K/DF2K_HR --output datasets/DF2K/DF2K_multiscale
+```
+
+#### Step 2: [Optional] Crop to sub-images
+
+We then crop DF2K images into sub-images for faster IO and processing.
+This step is optional if your IO is enough or your disk space is limited.
+
+You can use the [scripts/extract_subimages.py](scripts/extract_subimages.py) script. Here is the example:
+
+```bash
+ python scripts/extract_subimages.py --input datasets/DF2K/DF2K_multiscale --output datasets/DF2K/DF2K_multiscale_sub --crop_size 400 --step 200
+```
+
+#### Step 3: Prepare a txt for meta information
+
+You need to prepare a txt file containing the image paths. The following are some examples in `meta_info_DF2Kmultiscale+OST_sub.txt` (As different users may have different sub-images partitions, this file is not suitable for your purpose and you need to prepare your own txt file):
+
+```txt
+DF2K_HR_sub/000001_s001.png
+DF2K_HR_sub/000001_s002.png
+DF2K_HR_sub/000001_s003.png
+...
+```
+
+You can use the [scripts/generate_meta_info.py](scripts/generate_meta_info.py) script to generate the txt file.
+You can merge several folders into one meta_info txt. Here is the example:
+
+```bash
+ python scripts/generate_meta_info.py --input datasets/DF2K/DF2K_HR datasets/DF2K/DF2K_multiscale --root datasets/DF2K datasets/DF2K --meta_info datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt
+```
+
+### Train Real-ESRNet
+
+1. Download pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) into `experiments/pretrained_models`.
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
+ ```
+1. Modify the content in the option file `options/train_realesrnet_x4plus.yml` accordingly:
+ ```yml
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K # modify to the root path of your folder
+ meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
+ io_backend:
+ type: disk
+ ```
+1. If you want to perform validation during training, uncomment those lines and modify accordingly:
+ ```yml
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+ ...
+
+ # Uncomment these for validation
+ # validation settings
+ # val:
+ # val_freq: !!float 5e3
+ # save_img: True
+
+ # metrics:
+ # psnr: # metric name, can be arbitrary
+ # type: calculate_psnr
+ # crop_border: 4
+ # test_y_channel: false
+ ```
+1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
+ ```
+
+ Train with **a single GPU** in the *debug* mode:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --debug
+ ```
+1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
+ ```
+
+ Train with **a single GPU**:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --auto_resume
+ ```
+
+### Train Real-ESRGAN
+
+1. After the training of Real-ESRNet, you now have the file `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth`. If you need to specify the pre-trained path to other files, modify the `pretrain_network_g` value in the option file `train_realesrgan_x4plus.yml`.
+1. Modify the option file `train_realesrgan_x4plus.yml` accordingly. Most modifications are similar to those listed above.
+1. Before the formal training, you may run in the `--debug` mode to see whether everything is OK. We use four GPUs for training:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
+ ```
+
+ Train with **a single GPU** in the *debug* mode:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --debug
+ ```
+1. The formal training. We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+ ```
+
+ Train with **a single GPU**:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --auto_resume
+ ```
+
+## Finetune Real-ESRGAN on your own dataset
+
+You can finetune Real-ESRGAN on your own dataset. Typically, the fine-tuning process can be divided into two cases:
+
+1. [Generate degraded images on the fly](#Generate-degraded-images-on-the-fly)
+1. [Use your own **paired** data](#Use-paired-training-data)
+
+### Generate degraded images on the fly
+
+Only high-resolution images are required. The low-quality images are generated with the degradation process described in Real-ESRGAN during training.
+
+**1. Prepare dataset**
+
+See [this section](#dataset-preparation) for more details.
+
+**2. Download pre-trained models**
+
+Download pre-trained models into `experiments/pretrained_models`.
+
+- *RealESRGAN_x4plus.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
+ ```
+
+- *RealESRGAN_x4plus_netD.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
+ ```
+
+**3. Finetune**
+
+Modify [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml) accordingly, especially the `datasets` part:
+
+```yml
+train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K # modify to the root path of your folder
+ meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # modify to your own generate meta info txt
+ io_backend:
+ type: disk
+```
+
+We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+```
+
+Finetune with **a single GPU**:
+```bash
+python realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --auto_resume
+```
+
+### Use your own paired data
+
+You can also finetune RealESRGAN with your own paired data. It is more similar to fine-tuning ESRGAN.
+
+**1. Prepare dataset**
+
+Assume that you already have two folders:
+
+- **gt folder** (Ground-truth, high-resolution images): *datasets/DF2K/DIV2K_train_HR_sub*
+- **lq folder** (Low quality, low-resolution images): *datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub*
+
+Then, you can prepare the meta_info txt file using the script [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py):
+
+```bash
+python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
+```
+
+**2. Download pre-trained models**
+
+Download pre-trained models into `experiments/pretrained_models`.
+
+- *RealESRGAN_x4plus.pth*
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
+ ```
+
+- *RealESRGAN_x4plus_netD.pth*
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
+ ```
+
+**3. Finetune**
+
+Modify [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml) accordingly, especially the `datasets` part:
+
+```yml
+train:
+ name: DIV2K
+ type: RealESRGANPairedDataset
+ dataroot_gt: datasets/DF2K # modify to the root path of your folder
+ dataroot_lq: datasets/DF2K # modify to the root path of your folder
+ meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt # modify to your own generate meta info txt
+ io_backend:
+ type: disk
+```
+
+We use four GPUs for training. We use the `--auto_resume` argument to automatically resume the training if necessary.
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --launcher pytorch --auto_resume
+```
+
+Finetune with **a single GPU**:
+```bash
+python realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --auto_resume
+```
diff --git a/RealESRGANv030/docs/Training_CN.md b/RealESRGANv030/docs/Training_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..dabc3c5d97e134a2d551157c2dd03a629ec661bc
--- /dev/null
+++ b/RealESRGANv030/docs/Training_CN.md
@@ -0,0 +1,271 @@
+# :computer: 如何训练/微调 Real-ESRGAN
+
+- [训练 Real-ESRGAN](#训练-real-esrgan)
+ - [概述](#概述)
+ - [准备数据集](#准备数据集)
+ - [训练 Real-ESRNet 模型](#训练-real-esrnet-模型)
+ - [训练 Real-ESRGAN 模型](#训练-real-esrgan-模型)
+- [用自己的数据集微调 Real-ESRGAN](#用自己的数据集微调-real-esrgan)
+ - [动态生成降级图像](#动态生成降级图像)
+ - [使用已配对的数据](#使用已配对的数据)
+
+[English](Training.md) **|** [简体中文](Training_CN.md)
+
+## 训练 Real-ESRGAN
+
+### 概述
+
+训练分为两个步骤。除了 loss 函数外,这两个步骤拥有相同数据合成以及训练的一条龙流程。具体点说:
+
+1. 首先使用 L1 loss 训练 Real-ESRNet 模型,其中 L1 loss 来自预先训练的 ESRGAN 模型。
+
+2. 然后我们将 Real-ESRNet 模型作为生成器初始化,结合L1 loss、感知 loss、GAN loss 三者的参数对 Real-ESRGAN 进行训练。
+
+### 准备数据集
+
+我们使用 DF2K ( DIV2K 和 Flickr2K ) + OST 数据集进行训练。只需要HR图像!
+下面是网站链接:
+1. DIV2K: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
+2. Flickr2K: https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar
+3. OST: https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip
+
+以下是数据的准备步骤。
+
+#### 第1步:【可选】生成多尺寸图片
+
+针对 DF2K 数据集,我们使用多尺寸缩放策略,*换言之*,我们对 HR 图像进行下采样,就能获得多尺寸的标准参考(Ground-Truth)图像。
+您可以使用这个 [scripts/generate_multiscale_DF2K.py](scripts/generate_multiscale_DF2K.py) 脚本快速生成多尺寸的图像。
+注意:如果您只想简单试试,那么可以跳过此步骤。
+
+```bash
+python scripts/generate_multiscale_DF2K.py --input datasets/DF2K/DF2K_HR --output datasets/DF2K/DF2K_multiscale
+```
+
+#### 第2步:【可选】裁切为子图像
+
+我们可以将 DF2K 图像裁切为子图像,以加快 IO 和处理速度。
+如果你的 IO 够好或储存空间有限,那么此步骤是可选的。
+
+您可以使用脚本 [scripts/extract_subimages.py](scripts/extract_subimages.py)。这是使用示例:
+
+```bash
+ python scripts/extract_subimages.py --input datasets/DF2K/DF2K_multiscale --output datasets/DF2K/DF2K_multiscale_sub --crop_size 400 --step 200
+```
+
+#### 第3步:准备元信息 txt
+
+您需要准备一个包含图像路径的 txt 文件。下面是 `meta_info_DF2Kmultiscale+OST_sub.txt` 中的部分展示(由于各个用户可能有截然不同的子图像划分,这个文件不适合你的需求,你得准备自己的 txt 文件):
+
+```txt
+DF2K_HR_sub/000001_s001.png
+DF2K_HR_sub/000001_s002.png
+DF2K_HR_sub/000001_s003.png
+...
+```
+
+你可以使用该脚本 [scripts/generate_meta_info.py](scripts/generate_meta_info.py) 生成包含图像路径的 txt 文件。
+你还可以合并多个文件夹的图像路径到一个元信息(meta_info)txt。这是使用示例:
+
+```bash
+ python scripts/generate_meta_info.py --input datasets/DF2K/DF2K_HR, datasets/DF2K/DF2K_multiscale --root datasets/DF2K, datasets/DF2K --meta_info datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt
+```
+
+### 训练 Real-ESRNet 模型
+
+1. 下载预先训练的模型 [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth),放到 `experiments/pretrained_models`目录下。
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
+ ```
+2. 相应地修改选项文件 `options/train_realesrnet_x4plus.yml` 中的内容:
+ ```yml
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K # 修改为你的数据集文件夹根目录
+ meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # 修改为你自己生成的元信息txt
+ io_backend:
+ type: disk
+ ```
+3. 如果你想在训练过程中执行验证,就取消注释这些内容并进行相应的修改:
+ ```yml
+ # 取消注释这些以进行验证
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+ ...
+
+ # 取消注释这些以进行验证
+ # 验证设置
+ # val:
+ # val_freq: !!float 5e3
+ # save_img: True
+
+ # metrics:
+ # psnr: # 指标名称,可以是任意的
+ # type: calculate_psnr
+ # crop_border: 4
+ # test_y_channel: false
+ ```
+4. 正式训练之前,你可以用 `--debug` 模式检查是否正常运行。我们用了4个GPU进行训练:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --debug
+ ```
+
+ 用 **1个GPU** 训练的 debug 模式示例:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --debug
+ ```
+5. 正式训练开始。我们用了4个GPU进行训练。还可以使用参数 `--auto_resume` 在必要时自动恢复训练。
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --launcher pytorch --auto_resume
+ ```
+
+ 用 **1个GPU** 训练:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --auto_resume
+ ```
+
+### 训练 Real-ESRGAN 模型
+
+1. 训练 Real-ESRNet 模型后,您得到了这个 `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth` 文件。如果需要指定预训练路径到其他文件,请修改选项文件 `train_realesrgan_x4plus.yml` 中 `pretrain_network_g` 的值。
+1. 修改选项文件 `train_realesrgan_x4plus.yml` 的内容。大多数修改与上节提到的类似。
+1. 正式训练之前,你可以以 `--debug` 模式检查是否正常运行。我们使用了4个GPU进行训练:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --debug
+ ```
+
+ 用 **1个GPU** 训练的 debug 模式示例:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --debug
+ ```
+1. 正式训练开始。我们使用4个GPU进行训练。还可以使用参数 `--auto_resume` 在必要时自动恢复训练。
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+ ```
+
+ 用 **1个GPU** 训练:
+ ```bash
+ python realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --auto_resume
+ ```
+
+## 用自己的数据集微调 Real-ESRGAN
+
+你可以用自己的数据集微调 Real-ESRGAN。一般地,微调(Fine-Tune)程序可以分为两种类型:
+
+1. [动态生成降级图像](#动态生成降级图像)
+2. [使用**已配对**的数据](#使用已配对的数据)
+
+### 动态生成降级图像
+
+只需要高分辨率图像。在训练过程中,使用 Real-ESRGAN 描述的降级模型生成低质量图像。
+
+**1. 准备数据集**
+
+完整信息请参见[本节](#准备数据集)。
+
+**2. 下载预训练模型**
+
+下载预先训练的模型到 `experiments/pretrained_models` 目录下。
+
+- *RealESRGAN_x4plus.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
+ ```
+
+- *RealESRGAN_x4plus_netD.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
+ ```
+
+**3. 微调**
+
+修改选项文件 [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml) ,特别是 `datasets` 部分:
+
+```yml
+train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K # 修改为你的数据集文件夹根目录
+ meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt # 修改为你自己生成的元信息txt
+ io_backend:
+ type: disk
+```
+
+我们使用4个GPU进行训练。还可以使用参数 `--auto_resume` 在必要时自动恢复训练。
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --launcher pytorch --auto_resume
+```
+
+用 **1个GPU** 训练:
+```bash
+python realesrgan/train.py -opt options/finetune_realesrgan_x4plus.yml --auto_resume
+```
+
+### 使用已配对的数据
+
+你还可以用自己已经配对的数据微调 RealESRGAN。这个过程更类似于微调 ESRGAN。
+
+**1. 准备数据集**
+
+假设你已经有两个文件夹(folder):
+
+- **gt folder**(标准参考,高分辨率图像):*datasets/DF2K/DIV2K_train_HR_sub*
+- **lq folder**(低质量,低分辨率图像):*datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub*
+
+然后,您可以使用脚本 [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py) 生成元信息(meta_info)txt 文件。
+
+```bash
+python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
+```
+
+**2. 下载预训练模型**
+
+下载预先训练的模型到 `experiments/pretrained_models` 目录下。
+
+- *RealESRGAN_x4plus.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
+ ```
+
+- *RealESRGAN_x4plus_netD.pth*:
+ ```bash
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
+ ```
+
+**3. 微调**
+
+修改选项文件 [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml) ,特别是 `datasets` 部分:
+
+```yml
+train:
+ name: DIV2K
+ type: RealESRGANPairedDataset
+ dataroot_gt: datasets/DF2K # 修改为你的 gt folder 文件夹根目录
+ dataroot_lq: datasets/DF2K # 修改为你的 lq folder 文件夹根目录
+ meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt # 修改为你自己生成的元信息txt
+ io_backend:
+ type: disk
+```
+
+我们使用4个GPU进行训练。还可以使用参数 `--auto_resume` 在必要时自动恢复训练。
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --launcher pytorch --auto_resume
+```
+
+用 **1个GPU** 训练:
+```bash
+python realesrgan/train.py -opt options/finetune_realesrgan_x4plus_pairdata.yml --auto_resume
+```
diff --git a/RealESRGANv030/docs/anime_comparisons.md b/RealESRGANv030/docs/anime_comparisons.md
new file mode 100644
index 0000000000000000000000000000000000000000..09603bdc989bbf68b1f9f466acac5d8e442b8a01
--- /dev/null
+++ b/RealESRGANv030/docs/anime_comparisons.md
@@ -0,0 +1,66 @@
+# Comparisons among different anime models
+
+[English](anime_comparisons.md) **|** [简体中文](anime_comparisons_CN.md)
+
+## Update News
+
+- 2022/04/24: Release **AnimeVideo-v3**. We have made the following improvements:
+ - **better naturalness**
+ - **Fewer artifacts**
+ - **more faithful to the original colors**
+ - **better texture restoration**
+ - **better background restoration**
+
+## Comparisons
+
+We have compared our RealESRGAN-AnimeVideo-v3 with the following methods.
+Our RealESRGAN-AnimeVideo-v3 can achieve better results with faster inference speed.
+
+- [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan) with the hyperparameters: `tile=0`, `noiselevel=2`
+- [Real-CUGAN](https://github.com/bilibili/ailab/tree/main/Real-CUGAN): we use the [20220227](https://github.com/bilibili/ailab/releases/tag/Real-CUGAN-add-faster-low-memory-mode) version, the hyperparameters are: `cache_mode=0`, `tile=0`, `alpha=1`.
+- our RealESRGAN-AnimeVideo-v3
+
+## Results
+
+You may need to **zoom in** for comparing details, or **click the image** to see in the full size. Please note that the images
+in the table below are the resized and cropped patches from the original images, you can download the original inputs and outputs from [Google Drive](https://drive.google.com/drive/folders/1bc_Hje1Nqop9NDkUvci2VACSjL7HZMRp?usp=sharing) .
+
+**More natural results, better background restoration**
+| Input | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+
+**Fewer artifacts, better detailed textures**
+| Input | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+
+**Other better results**
+| Input | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+|  |   |   |   |
+| |  |  |  |
+| |  |  |  |
+
+## Inference Speed
+
+### PyTorch
+
+Note that we only report the **model** time, and ignore the IO time.
+
+| GPU | Input Resolution | waifu2x | Real-CUGAN | RealESRGAN-AnimeVideo-v3
+| :---: | :---: | :---: | :---: | :---: |
+| V100 | 1921 x 1080 | - | 3.4 fps | **10.0** fps |
+| V100 | 1280 x 720 | - | 7.2 fps | **22.6** fps |
+| V100 | 640 x 480 | - | 24.4 fps | **65.9** fps |
+
+### ncnn
+
+- [ ] TODO
diff --git a/RealESRGANv030/docs/anime_comparisons_CN.md b/RealESRGANv030/docs/anime_comparisons_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..43ba58344ed9554d5b30e2815d1b7d4ab8bc503f
--- /dev/null
+++ b/RealESRGANv030/docs/anime_comparisons_CN.md
@@ -0,0 +1,68 @@
+# 动漫视频模型比较
+
+[English](anime_comparisons.md) **|** [简体中文](anime_comparisons_CN.md)
+
+## 更新
+
+- 2022/04/24: 发布 **AnimeVideo-v3**. 主要做了以下更新:
+ - **更自然**
+ - **更少瑕疵**
+ - **颜色保持得更好**
+ - **更好的纹理恢复**
+ - **虚化背景处理**
+
+## 比较
+
+我们将 RealESRGAN-AnimeVideo-v3 与以下方法进行了比较。我们的 RealESRGAN-AnimeVideo-v3 可以以更快的推理速度获得更好的结果。
+
+- [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan). 超参数: `tile=0`, `noiselevel=2`
+- [Real-CUGAN](https://github.com/bilibili/ailab/tree/main/Real-CUGAN): 我们使用了[20220227](https://github.com/bilibili/ailab/releases/tag/Real-CUGAN-add-faster-low-memory-mode)版本, 超参: `cache_mode=0`, `tile=0`, `alpha=1`.
+- 我们的 RealESRGAN-AnimeVideo-v3
+
+## 结果
+
+您可能需要**放大**以比较详细信息, 或者**单击图像**以查看完整尺寸。 请注意下面表格的图片是从原图里裁剪patch并且resize后的结果,您可以从
+[Google Drive](https://drive.google.com/drive/folders/1bc_Hje1Nqop9NDkUvci2VACSjL7HZMRp?usp=sharing) 里下载原始的输入和输出。
+
+**更自然的结果,更好的虚化背景恢复**
+
+| 输入 | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+
+**更少瑕疵,更好的细节纹理**
+
+| 输入 | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+| |  |  |  |
+
+**其他更好的结果**
+
+| 输入 | waifu2x | Real-CUGAN | RealESRGAN
AnimeVideo-v3 |
+| :---: | :---: | :---: | :---: |
+| |  |  |  |
+| |  |  |  |
+|  |   |   |   |
+| |  |  |  |
+| |  |  |  |
+
+## 推理速度比较
+
+### PyTorch
+
+请注意,我们只报告了**模型推理**的时间, 而忽略了读写硬盘的时间.
+
+| GPU | 输入尺寸 | waifu2x | Real-CUGAN | RealESRGAN-AnimeVideo-v3
+| :---: | :---: | :---: | :---: | :---: |
+| V100 | 1921 x 1080 | - | 3.4 fps | **10.0** fps |
+| V100 | 1280 x 720 | - | 7.2 fps | **22.6** fps |
+| V100 | 640 x 480 | - | 24.4 fps | **65.9** fps |
+
+### ncnn
+
+- [ ] TODO
diff --git a/RealESRGANv030/docs/anime_model.md b/RealESRGANv030/docs/anime_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..213328d92d0dbaeb188f8ef0f47192e74efeaccc
--- /dev/null
+++ b/RealESRGANv030/docs/anime_model.md
@@ -0,0 +1,68 @@
+# Anime Model
+
+:white_check_mark: We add [*RealESRGAN_x4plus_anime_6B.pth*](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth), which is optimized for **anime** images with much smaller model size.
+
+- [How to Use](#how-to-use)
+ - [PyTorch Inference](#pytorch-inference)
+ - [ncnn Executable File](#ncnn-executable-file)
+- [Comparisons with waifu2x](#comparisons-with-waifu2x)
+- [Comparisons with Sliding Bars](#comparisons-with-sliding-bars)
+
+
+
+
+
+The following is a video comparison with sliding bar. You may need to use the full-screen mode for better visual quality, as the original image is large; otherwise, you may encounter aliasing issue.
+
+
+
+## How to Use
+
+### PyTorch Inference
+
+Pre-trained models: [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth)
+
+```bash
+# download model
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
+# inference
+python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
+```
+
+### ncnn Executable File
+
+Download the latest portable [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip) **executable files for Intel/AMD/Nvidia GPU**.
+
+Taking the Windows as example, run:
+
+```bash
+./realesrgan-ncnn-vulkan.exe -i input.jpg -o output.png -n realesrgan-x4plus-anime
+```
+
+## Comparisons with waifu2x
+
+We compare Real-ESRGAN-anime with [waifu2x](https://github.com/nihui/waifu2x-ncnn-vulkan). We use the `-n 2 -s 4` for waifu2x.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## Comparisons with Sliding Bars
+
+The following are video comparisons with sliding bar. You may need to use the full-screen mode for better visual quality, as the original image is large; otherwise, you may encounter aliasing issue.
+
+
+
+
diff --git a/RealESRGANv030/docs/anime_video_model.md b/RealESRGANv030/docs/anime_video_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..0ad5c85804c1f8636c3720a652b40bbd9df0fe2e
--- /dev/null
+++ b/RealESRGANv030/docs/anime_video_model.md
@@ -0,0 +1,136 @@
+# Anime Video Models
+
+:white_check_mark: We add small models that are optimized for anime videos :-)
+More comparisons can be found in [anime_comparisons.md](anime_comparisons.md)
+
+- [How to Use](#how-to-use)
+- [PyTorch Inference](#pytorch-inference)
+- [ncnn Executable File](#ncnn-executable-file)
+ - [Step 1: Use ffmpeg to extract frames from video](#step-1-use-ffmpeg-to-extract-frames-from-video)
+ - [Step 2: Inference with Real-ESRGAN executable file](#step-2-inference-with-real-esrgan-executable-file)
+ - [Step 3: Merge the enhanced frames back into a video](#step-3-merge-the-enhanced-frames-back-into-a-video)
+- [More Demos](#more-demos)
+
+| Models | Scale | Description |
+| ---------------------------------------------------------------------------------------------------------------------------------- | :---- | :----------------------------- |
+| [realesr-animevideov3](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth) | X4 1 | Anime video model with XS size |
+
+Note:
+1 This model can also be used for X1, X2, X3.
+
+---
+
+The following are some demos (best view in the full screen mode).
+
+
+
+
+
+
+
+## How to Use
+
+### PyTorch Inference
+
+```bash
+# download model
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P weights
+# single gpu and single process inference
+CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2
+# single gpu and multi process inference (you can use multi-processing to improve GPU utilization)
+CUDA_VISIBLE_DEVICES=0 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
+# multi gpu and multi process inference
+CUDA_VISIBLE_DEVICES=0,1,2,3 python inference_realesrgan_video.py -i inputs/video/onepiece_demo.mp4 -n realesr-animevideov3 -s 2 --suffix outx2 --num_process_per_gpu 2
+```
+
+```console
+Usage:
+--num_process_per_gpu The total number of process is num_gpu * num_process_per_gpu. The bottleneck of
+ the program lies on the IO, so the GPUs are usually not fully utilized. To alleviate
+ this issue, you can use multi-processing by setting this parameter. As long as it
+ does not exceed the CUDA memory
+--extract_frame_first If you encounter ffmpeg error when using multi-processing, you can turn this option on.
+```
+
+### NCNN Executable File
+
+#### Step 1: Use ffmpeg to extract frames from video
+
+```bash
+ffmpeg -i onepiece_demo.mp4 -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 tmp_frames/frame%08d.png
+```
+
+- Remember to create the folder `tmp_frames` ahead
+
+#### Step 2: Inference with Real-ESRGAN executable file
+
+1. Download the latest portable [Windows](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-windows.zip) / [Linux](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-ubuntu.zip) / [MacOS](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesrgan-ncnn-vulkan-20220424-macos.zip) **executable files for Intel/AMD/Nvidia GPU**
+
+1. Taking the Windows as example, run:
+
+ ```bash
+ ./realesrgan-ncnn-vulkan.exe -i tmp_frames -o out_frames -n realesr-animevideov3 -s 2 -f jpg
+ ```
+
+ - Remember to create the folder `out_frames` ahead
+
+#### Step 3: Merge the enhanced frames back into a video
+
+1. First obtain fps from input videos by
+
+ ```bash
+ ffmpeg -i onepiece_demo.mp4
+ ```
+
+ ```console
+ Usage:
+ -i input video path
+ ```
+
+ You will get the output similar to the following screenshot.
+
+
+
+
+
+2. Merge frames
+
+ ```bash
+ ffmpeg -r 23.98 -i out_frames/frame%08d.jpg -c:v libx264 -r 23.98 -pix_fmt yuv420p output.mp4
+ ```
+
+ ```console
+ Usage:
+ -i input video path
+ -c:v video encoder (usually we use libx264)
+ -r fps, remember to modify it to meet your needs
+ -pix_fmt pixel format in video
+ ```
+
+ If you also want to copy audio from the input videos, run:
+
+ ```bash
+ ffmpeg -r 23.98 -i out_frames/frame%08d.jpg -i onepiece_demo.mp4 -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r 23.98 -pix_fmt yuv420p output_w_audio.mp4
+ ```
+
+ ```console
+ Usage:
+ -i input video path, here we use two input streams
+ -c:v video encoder (usually we use libx264)
+ -r fps, remember to modify it to meet your needs
+ -pix_fmt pixel format in video
+ ```
+
+## More Demos
+
+- Input video for One Piece:
+
+
+
+- Out video for One Piece
+
+
+
+**More comparisons**
+
+
diff --git a/RealESRGANv030/docs/feedback.md b/RealESRGANv030/docs/feedback.md
new file mode 100644
index 0000000000000000000000000000000000000000..c621ed05e9bc122a2ae6309eac61583ab9f35e7a
--- /dev/null
+++ b/RealESRGANv030/docs/feedback.md
@@ -0,0 +1,11 @@
+# Feedback 反馈
+
+## 动漫插画模型
+
+1. 视频处理不了: 目前的模型,不是针对视频的,所以视频效果很很不好。我们在探究针对视频的模型了
+1. 景深虚化有问题: 现在的模型把一些景深 和 特意的虚化 都复原了,感觉不好。这个后面我们会考虑把这个信息结合进入。一个简单的做法是识别景深和虚化,然后作为条件告诉神经网络,哪些地方复原强一些,哪些地方复原要弱一些
+1. 不可以调节: 像 Waifu2X 可以调节。可以根据自己的喜好,做调整,但是 Real-ESRGAN-anime 并不可以。导致有些恢复效果过了
+1. 把原来的风格改变了: 不同的动漫插画都有自己的风格,现在的 Real-ESRGAN-anime 倾向于恢复成一种风格(这是受到训练数据集影响的)。风格是动漫很重要的一个要素,所以要尽可能保持
+1. 模型太大: 目前的模型处理太慢,能够更快。这个我们有相关的工作在探究,希望能够尽快有结果,并应用到 Real-ESRGAN 这一系列的模型上
+
+Thanks for the [detailed and valuable feedbacks/suggestions](https://github.com/xinntao/Real-ESRGAN/issues/131) by [2ji3150](https://github.com/2ji3150).
diff --git a/RealESRGANv030/docs/model_zoo.md b/RealESRGANv030/docs/model_zoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..132cc514bac6b447addac8485e0622a834d34474
--- /dev/null
+++ b/RealESRGANv030/docs/model_zoo.md
@@ -0,0 +1,49 @@
+# :european_castle: Model Zoo
+
+- [For General Images](#for-general-images)
+- [For Anime Images](#for-anime-images)
+- [For Anime Videos](#for-anime-videos)
+
+---
+
+## For General Images
+
+| Models | Scale | Description |
+| ------------------------------------------------------------------------------------------------------------------------------- | :---- | :------------------------------------------- |
+| [RealESRGAN_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) | X4 | X4 model for general images |
+| [RealESRGAN_x2plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth) | X2 | X2 model for general images |
+| [RealESRNet_x4plus](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth) | X4 | X4 model with MSE loss (over-smooth effects) |
+| [official ESRGAN_x4](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) | X4 | official ESRGAN model |
+| [realesr-general-x4v3](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth) | X4 (can also be used for X1, X2, X3) | A tiny small model (consume much fewer GPU memory and time); not too strong deblur and denoise capacity |
+
+The following models are **discriminators**, which are usually used for fine-tuning.
+
+| Models | Corresponding model |
+| ---------------------------------------------------------------------------------------------------------------------- | :------------------ |
+| [RealESRGAN_x4plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth) | RealESRGAN_x4plus |
+| [RealESRGAN_x2plus_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x2plus_netD.pth) | RealESRGAN_x2plus |
+
+## For Anime Images / Illustrations
+
+| Models | Scale | Description |
+| ------------------------------------------------------------------------------------------------------------------------------ | :---- | :---------------------------------------------------------- |
+| [RealESRGAN_x4plus_anime_6B](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth) | X4 | Optimized for anime images; 6 RRDB blocks (smaller network) |
+
+The following models are **discriminators**, which are usually used for fine-tuning.
+
+| Models | Corresponding model |
+| ---------------------------------------------------------------------------------------------------------------------------------------- | :------------------------- |
+| [RealESRGAN_x4plus_anime_6B_netD](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B_netD.pth) | RealESRGAN_x4plus_anime_6B |
+
+## For Animation Videos
+
+| Models | Scale | Description |
+| ---------------------------------------------------------------------------------------------------------------------------------- | :---- | :----------------------------- |
+| [realesr-animevideov3](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth) | X41 | Anime video model with XS size |
+
+Note:
+1 This model can also be used for X1, X2, X3.
+
+The following models are **discriminators**, which are usually used for fine-tuning.
+
+TODO
diff --git a/RealESRGANv030/docs/ncnn_conversion.md b/RealESRGANv030/docs/ncnn_conversion.md
new file mode 100644
index 0000000000000000000000000000000000000000..e1785cd079ccbb6f0a5ddefe24f63bfe81ce9b21
--- /dev/null
+++ b/RealESRGANv030/docs/ncnn_conversion.md
@@ -0,0 +1,11 @@
+# Instructions on converting to NCNN models
+
+1. Convert to onnx model with `scripts/pytorch2onnx.py`. Remember to modify codes accordingly
+1. Convert onnx model to ncnn model
+ 1. `cd ncnn-master\ncnn\build\tools\onnx`
+ 1. `onnx2ncnn.exe realesrgan-x4.onnx realesrgan-x4-raw.param realesrgan-x4-raw.bin`
+1. Optimize ncnn model
+ 1. fp16 mode
+ 1. `cd ncnn-master\ncnn\build\tools`
+ 1. `ncnnoptimize.exe realesrgan-x4-raw.param realesrgan-x4-raw.bin realesrgan-x4.param realesrgan-x4.bin 1`
+1. Modify the blob name in `realesrgan-x4.param`: `data` and `output`
diff --git a/RealESRGANv030/experiments/pretrained_models/README.md b/RealESRGANv030/experiments/pretrained_models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d0cc4afcbdd2c733f6b946bb86bd00baa90e8295
--- /dev/null
+++ b/RealESRGANv030/experiments/pretrained_models/README.md
@@ -0,0 +1 @@
+# Put downloaded pre-trained models here
diff --git a/RealESRGANv030/inference_realesrgan.py b/RealESRGANv030/inference_realesrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a8cc43addb2e8e94b9920cef109443c7f475241
--- /dev/null
+++ b/RealESRGANv030/inference_realesrgan.py
@@ -0,0 +1,166 @@
+import argparse
+import cv2
+import glob
+import os
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.download_util import load_file_from_url
+
+from realesrgan import RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+
+def main():
+ """Inference demo for Real-ESRGAN.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
+ parser.add_argument(
+ '-n',
+ '--model_name',
+ type=str,
+ default='RealESRGAN_x4plus',
+ help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
+ 'realesr-animevideov3 | realesr-general-x4v3'))
+ parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+ parser.add_argument(
+ '-dn',
+ '--denoise_strength',
+ type=float,
+ default=0.5,
+ help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
+ 'Only used for the realesr-general-x4v3 model'))
+ parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
+ parser.add_argument(
+ '--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
+ parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
+ parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+ parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+ parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+ parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+ parser.add_argument(
+ '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
+ parser.add_argument(
+ '--alpha_upsampler',
+ type=str,
+ default='realesrgan',
+ help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+ parser.add_argument(
+ '--ext',
+ type=str,
+ default='auto',
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+ parser.add_argument(
+ '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')
+
+ args = parser.parse_args()
+
+ # determine models according to model names
+ args.model_name = args.model_name.split('.')[0]
+ if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+ elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+ elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+ elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ netscale = 2
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+ elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
+ elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = [
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+ ]
+
+ # determine model paths
+ if args.model_path is not None:
+ model_path = args.model_path
+ else:
+ model_path = os.path.join('weights', args.model_name + '.pth')
+ if not os.path.isfile(model_path):
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ for url in file_url:
+ # model_path will be updated
+ model_path = load_file_from_url(
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+
+ # use dni to control the denoise strength
+ dni_weight = None
+ if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+ model_path = [model_path, wdn_model_path]
+ dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
+
+ # restorer
+ upsampler = RealESRGANer(
+ scale=netscale,
+ model_path=model_path,
+ dni_weight=dni_weight,
+ model=model,
+ tile=args.tile,
+ tile_pad=args.tile_pad,
+ pre_pad=args.pre_pad,
+ half=not args.fp32,
+ gpu_id=args.gpu_id)
+
+ if args.face_enhance: # Use GFPGAN for face enhancement
+ from gfpgan import GFPGANer
+ face_enhancer = GFPGANer(
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+ upscale=args.outscale,
+ arch='clean',
+ channel_multiplier=2,
+ bg_upsampler=upsampler)
+ os.makedirs(args.output, exist_ok=True)
+
+ if os.path.isfile(args.input):
+ paths = [args.input]
+ else:
+ paths = sorted(glob.glob(os.path.join(args.input, '*')))
+
+ for idx, path in enumerate(paths):
+ imgname, extension = os.path.splitext(os.path.basename(path))
+ print('Testing', idx, imgname)
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ else:
+ img_mode = None
+
+ try:
+ if args.face_enhance:
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+ else:
+ output, _ = upsampler.enhance(img, outscale=args.outscale)
+ except RuntimeError as error:
+ print('Error', error)
+ print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+ else:
+ if args.ext == 'auto':
+ extension = extension[1:]
+ else:
+ extension = args.ext
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ extension = 'png'
+ if args.suffix == '':
+ save_path = os.path.join(args.output, f'{imgname}.{extension}')
+ else:
+ save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
+ cv2.imwrite(save_path, output)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/RealESRGANv030/inference_realesrgan_video.py b/RealESRGANv030/inference_realesrgan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3c4d1465d552f4724092fd1b8250de1f7b0f4e8
--- /dev/null
+++ b/RealESRGANv030/inference_realesrgan_video.py
@@ -0,0 +1,398 @@
+import argparse
+import cv2
+import glob
+import mimetypes
+import numpy as np
+import os
+import shutil
+import subprocess
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.download_util import load_file_from_url
+from os import path as osp
+from tqdm import tqdm
+
+from realesrgan import RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+try:
+ import ffmpeg
+except ImportError:
+ import pip
+ pip.main(['install', '--user', 'ffmpeg-python'])
+ import ffmpeg
+
+
+def get_video_meta_info(video_path):
+ ret = {}
+ probe = ffmpeg.probe(video_path)
+ video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
+ has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
+ ret['width'] = video_streams[0]['width']
+ ret['height'] = video_streams[0]['height']
+ ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
+ ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
+ ret['nb_frames'] = int(video_streams[0]['nb_frames'])
+ return ret
+
+
+def get_sub_video(args, num_process, process_idx):
+ if num_process == 1:
+ return args.input
+ meta = get_video_meta_info(args.input)
+ duration = int(meta['nb_frames'] / meta['fps'])
+ part_time = duration // num_process
+ print(f'duration: {duration}, part_time: {part_time}')
+ os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True)
+ out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4')
+ cmd = [
+ args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}',
+ f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y'
+ ]
+ print(' '.join(cmd))
+ subprocess.call(' '.join(cmd), shell=True)
+ return out_path
+
+
+class Reader:
+
+ def __init__(self, args, total_workers=1, worker_idx=0):
+ self.args = args
+ input_type = mimetypes.guess_type(args.input)[0]
+ self.input_type = 'folder' if input_type is None else input_type
+ self.paths = [] # for image&folder type
+ self.audio = None
+ self.input_fps = None
+ if self.input_type.startswith('video'):
+ video_path = get_sub_video(args, total_workers, worker_idx)
+ self.stream_reader = (
+ ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
+ loglevel='error').run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+ meta = get_video_meta_info(video_path)
+ self.width = meta['width']
+ self.height = meta['height']
+ self.input_fps = meta['fps']
+ self.audio = meta['audio']
+ self.nb_frames = meta['nb_frames']
+
+ else:
+ if self.input_type.startswith('image'):
+ self.paths = [args.input]
+ else:
+ paths = sorted(glob.glob(os.path.join(args.input, '*')))
+ tot_frames = len(paths)
+ num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0)
+ self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker * (worker_idx + 1)]
+
+ self.nb_frames = len(self.paths)
+ assert self.nb_frames > 0, 'empty folder'
+ from PIL import Image
+ tmp_img = Image.open(self.paths[0])
+ self.width, self.height = tmp_img.size
+ self.idx = 0
+
+ def get_resolution(self):
+ return self.height, self.width
+
+ def get_fps(self):
+ if self.args.fps is not None:
+ return self.args.fps
+ elif self.input_fps is not None:
+ return self.input_fps
+ return 24
+
+ def get_audio(self):
+ return self.audio
+
+ def __len__(self):
+ return self.nb_frames
+
+ def get_frame_from_stream(self):
+ img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
+ if not img_bytes:
+ return None
+ img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
+ return img
+
+ def get_frame_from_list(self):
+ if self.idx >= self.nb_frames:
+ return None
+ img = cv2.imread(self.paths[self.idx])
+ self.idx += 1
+ return img
+
+ def get_frame(self):
+ if self.input_type.startswith('video'):
+ return self.get_frame_from_stream()
+ else:
+ return self.get_frame_from_list()
+
+ def close(self):
+ if self.input_type.startswith('video'):
+ self.stream_reader.stdin.close()
+ self.stream_reader.wait()
+
+
+class Writer:
+
+ def __init__(self, args, audio, height, width, video_save_path, fps):
+ out_width, out_height = int(width * args.outscale), int(height * args.outscale)
+ if out_height > 2160:
+ print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
+ 'We highly recommend to decrease the outscale(aka, -s).')
+
+ if audio is not None:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
+ framerate=fps).output(
+ audio,
+ video_save_path,
+ pix_fmt='yuv420p',
+ vcodec='libx264',
+ loglevel='error',
+ acodec='copy').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+ else:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{out_width}x{out_height}',
+ framerate=fps).output(
+ video_save_path, pix_fmt='yuv420p', vcodec='libx264',
+ loglevel='error').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+
+ def write_frame(self, frame):
+ frame = frame.astype(np.uint8).tobytes()
+ self.stream_writer.stdin.write(frame)
+
+ def close(self):
+ self.stream_writer.stdin.close()
+ self.stream_writer.wait()
+
+
+def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
+ # ---------------------- determine models according to model names ---------------------- #
+ args.model_name = args.model_name.split('.pth')[0]
+ if args.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+ elif args.model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+ elif args.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+ elif args.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ netscale = 2
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+ elif args.model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
+ elif args.model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = [
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+ ]
+
+ # ---------------------- determine model paths ---------------------- #
+ model_path = os.path.join('weights', args.model_name + '.pth')
+ if not os.path.isfile(model_path):
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ for url in file_url:
+ # model_path will be updated
+ model_path = load_file_from_url(
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+
+ # use dni to control the denoise strength
+ dni_weight = None
+ if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+ model_path = [model_path, wdn_model_path]
+ dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
+
+ # restorer
+ upsampler = RealESRGANer(
+ scale=netscale,
+ model_path=model_path,
+ dni_weight=dni_weight,
+ model=model,
+ tile=args.tile,
+ tile_pad=args.tile_pad,
+ pre_pad=args.pre_pad,
+ half=not args.fp32,
+ device=device,
+ )
+
+ if 'anime' in args.model_name and args.face_enhance:
+ print('face_enhance is not supported in anime models, we turned this option off for you. '
+ 'if you insist on turning it on, please manually comment the relevant lines of code.')
+ args.face_enhance = False
+
+ if args.face_enhance: # Use GFPGAN for face enhancement
+ from gfpgan import GFPGANer
+ face_enhancer = GFPGANer(
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+ upscale=args.outscale,
+ arch='clean',
+ channel_multiplier=2,
+ bg_upsampler=upsampler) # TODO support custom device
+ else:
+ face_enhancer = None
+
+ reader = Reader(args, total_workers, worker_idx)
+ audio = reader.get_audio()
+ height, width = reader.get_resolution()
+ fps = reader.get_fps()
+ writer = Writer(args, audio, height, width, video_save_path, fps)
+
+ pbar = tqdm(total=len(reader), unit='frame', desc='inference')
+ while True:
+ img = reader.get_frame()
+ if img is None:
+ break
+
+ try:
+ if args.face_enhance:
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+ else:
+ output, _ = upsampler.enhance(img, outscale=args.outscale)
+ except RuntimeError as error:
+ print('Error', error)
+ print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+ else:
+ writer.write_frame(output)
+
+ torch.cuda.synchronize(device)
+ pbar.update(1)
+
+ reader.close()
+ writer.close()
+
+
+def run(args):
+ args.video_name = osp.splitext(os.path.basename(args.input))[0]
+ video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
+
+ if args.extract_frame_first:
+ tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+ os.makedirs(tmp_frames_folder, exist_ok=True)
+ os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
+ args.input = tmp_frames_folder
+
+ num_gpus = torch.cuda.device_count()
+ num_process = num_gpus * args.num_process_per_gpu
+ if num_process == 1:
+ inference_video(args, video_save_path)
+ return
+
+ ctx = torch.multiprocessing.get_context('spawn')
+ pool = ctx.Pool(num_process)
+ os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
+ pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
+ for i in range(num_process):
+ sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
+ pool.apply_async(
+ inference_video,
+ args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
+ callback=lambda arg: pbar.update(1))
+ pool.close()
+ pool.join()
+
+ # combine sub videos
+ # prepare vidlist.txt
+ with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
+ for i in range(num_process):
+ f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
+
+ cmd = [
+ args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
+ 'copy', f'{video_save_path}'
+ ]
+ print(' '.join(cmd))
+ subprocess.call(cmd)
+ shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
+ if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
+ shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
+ os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
+
+
+def main():
+ """Inference demo for Real-ESRGAN.
+ It mainly for restoring anime videos.
+
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
+ parser.add_argument(
+ '-n',
+ '--model_name',
+ type=str,
+ default='realesr-animevideov3',
+ help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
+ ' RealESRGAN_x2plus | realesr-general-x4v3'
+ 'Default:realesr-animevideov3'))
+ parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+ parser.add_argument(
+ '-dn',
+ '--denoise_strength',
+ type=float,
+ default=0.5,
+ help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
+ 'Only used for the realesr-general-x4v3 model'))
+ parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
+ parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
+ parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+ parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+ parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+ parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+ parser.add_argument(
+ '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
+ parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
+ parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
+ parser.add_argument('--extract_frame_first', action='store_true')
+ parser.add_argument('--num_process_per_gpu', type=int, default=1)
+
+ parser.add_argument(
+ '--alpha_upsampler',
+ type=str,
+ default='realesrgan',
+ help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+ parser.add_argument(
+ '--ext',
+ type=str,
+ default='auto',
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+ args = parser.parse_args()
+
+ args.input = args.input.rstrip('/').rstrip('\\')
+ os.makedirs(args.output, exist_ok=True)
+
+ if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
+ is_video = True
+ else:
+ is_video = False
+
+ if is_video and args.input.endswith('.flv'):
+ mp4_path = args.input.replace('.flv', '.mp4')
+ os.system(f'ffmpeg -i {args.input} -codec copy {mp4_path}')
+ args.input = mp4_path
+
+ if args.extract_frame_first and not is_video:
+ args.extract_frame_first = False
+
+ run(args)
+
+ if args.extract_frame_first:
+ tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+ shutil.rmtree(tmp_frames_folder)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/RealESRGANv030/inputs/00003.png b/RealESRGANv030/inputs/00003.png
new file mode 100644
index 0000000000000000000000000000000000000000..00cad23adf5d658caf03a0a2874f0c89d96c5ddc
Binary files /dev/null and b/RealESRGANv030/inputs/00003.png differ
diff --git a/RealESRGANv030/inputs/00017_gray.png b/RealESRGANv030/inputs/00017_gray.png
new file mode 100644
index 0000000000000000000000000000000000000000..79af68e8aa0f036211734b7271633d88b2fc8f0d
Binary files /dev/null and b/RealESRGANv030/inputs/00017_gray.png differ
diff --git a/RealESRGANv030/inputs/0014.jpg b/RealESRGANv030/inputs/0014.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f59554fe3143b3ffa27d6fcb04143124b4d0412b
Binary files /dev/null and b/RealESRGANv030/inputs/0014.jpg differ
diff --git a/RealESRGANv030/inputs/0030.jpg b/RealESRGANv030/inputs/0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..61868926af738046e984bcf652134e3ea9b958d9
Binary files /dev/null and b/RealESRGANv030/inputs/0030.jpg differ
diff --git a/RealESRGANv030/inputs/ADE_val_00000114.jpg b/RealESRGANv030/inputs/ADE_val_00000114.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b4d9c9067adbcdd153527cef2c0cab4cf40bbfa5
Binary files /dev/null and b/RealESRGANv030/inputs/ADE_val_00000114.jpg differ
diff --git a/RealESRGANv030/inputs/OST_009.png b/RealESRGANv030/inputs/OST_009.png
new file mode 100644
index 0000000000000000000000000000000000000000..10bbc831acb7065827a14eb7e0538312a8d6f3e2
Binary files /dev/null and b/RealESRGANv030/inputs/OST_009.png differ
diff --git a/RealESRGANv030/inputs/children-alpha.png b/RealESRGANv030/inputs/children-alpha.png
new file mode 100644
index 0000000000000000000000000000000000000000..41dcc3b6cc7a8a1b073f6dbe09d0c12e18c1b4b3
Binary files /dev/null and b/RealESRGANv030/inputs/children-alpha.png differ
diff --git a/RealESRGANv030/inputs/tree_alpha_16bit.png b/RealESRGANv030/inputs/tree_alpha_16bit.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca7c2aac2c5c9cdaea66ecc8e06d6b43e3d8bf20
Binary files /dev/null and b/RealESRGANv030/inputs/tree_alpha_16bit.png differ
diff --git a/RealESRGANv030/inputs/video/onepiece_demo.mp4 b/RealESRGANv030/inputs/video/onepiece_demo.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..29b4e5246b19008885611c23921fe4423f17e43f
Binary files /dev/null and b/RealESRGANv030/inputs/video/onepiece_demo.mp4 differ
diff --git a/RealESRGANv030/inputs/wolf_gray.jpg b/RealESRGANv030/inputs/wolf_gray.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..614766bdbcaa3730a8191afcb9616305381245ea
Binary files /dev/null and b/RealESRGANv030/inputs/wolf_gray.jpg differ
diff --git a/RealESRGANv030/interface.py b/RealESRGANv030/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d6e67aa0c8825ab668f07a800b4dc4503d2e38a
--- /dev/null
+++ b/RealESRGANv030/interface.py
@@ -0,0 +1,139 @@
+import cv2
+from PIL import Image
+import glob
+import os
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.download_util import load_file_from_url
+
+from realesrgan import RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+def realEsrgan(model_name="RealESRGAN_x4plus_anime_6B",
+ model_path = None,
+ input_dir = 'inputs',
+ output_dir = 'results',
+ denoise_strength = 0.5,
+ outscale = 4,
+ suffix = 'out',
+ tile = 200,
+ tile_pad = 10,
+ pre_pad = 0,
+ face_enhance = True,
+ alpha_upsampler = 'realsrgan',
+ out_ext = 'auto',
+ fp32 = True,
+ gpu_id = None,
+ ):
+
+ # determine models according to model names
+ model_name = model_name.split('.')[0]
+ if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+ elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+ elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+ elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ netscale = 2
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+ elif model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
+ elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ netscale = 4
+ file_url = [
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+ ]
+
+ # determine model paths
+ if model_path is None:
+ model_path = os.path.join('weights', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ for url in file_url:
+ # model_path will be updated
+ model_path = load_file_from_url(
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+
+ # use dni to control the denoise strength
+ dni_weight = None
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+ model_path = [model_path, wdn_model_path]
+ dni_weight = [denoise_strength, 1 - denoise_strength]
+
+ # restorer
+ upsampler = RealESRGANer(
+ scale=netscale,
+ model_path=model_path,
+ dni_weight=dni_weight,
+ model=model,
+ tile=tile,
+ tile_pad=tile_pad,
+ pre_pad=pre_pad,
+ half=not fp32,
+ gpu_id=gpu_id)
+
+ if face_enhance: # Use GFPGAN for face enhancement
+ from gfpgan import GFPGANer
+ face_enhancer = GFPGANer(
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+ upscale=outscale,
+ arch='clean',
+ channel_multiplier=2,
+ bg_upsampler=upsampler)
+ os.makedirs(output_dir, exist_ok=True)
+
+ if os.path.isfile(input_dir):
+ paths = [input_dir]
+ else:
+ paths = sorted(glob.glob(os.path.join(input_dir, '*')))
+
+ Imgs = []
+ for idx, path in enumerate(paths):
+ imgname, extension = os.path.splitext(os.path.basename(path))
+ print('Enhancing the resolution:', idx, imgname)
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ else:
+ img_mode = None
+
+ try:
+ if face_enhance:
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+ else:
+ output, _ = upsampler.enhance(img, outscale=outscale)
+ except RuntimeError as error:
+ print('Error', error)
+ print('If you encounter CUDA or RAM out of memory, try to set --tile with a smaller number.')
+ else:
+ if out_ext == 'auto':
+ extension = extension[1:]
+ else:
+ extension = out_ext
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ extension = 'png'
+ if suffix == '':
+ save_path = os.path.join(output_dir, f'{imgname}.{extension}')
+ else:
+ save_path = os.path.join(output_dir, f'{imgname}_{suffix}.{extension}')
+
+ cv2.imwrite(save_path, output)
+
+ img = Image.fromarray(output.astype('uint8'), 'RGB')
+ Imgs.append(img)
+
+ return Imgs
+
diff --git a/RealESRGANv030/options/finetune_realesrgan_x4plus.yml b/RealESRGANv030/options/finetune_realesrgan_x4plus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aa9806570025dce0a967ca0541a0ea497a57d6a9
--- /dev/null
+++ b/RealESRGANv030/options/finetune_realesrgan_x4plus.yml
@@ -0,0 +1,188 @@
+# general settings
+name: finetune_RealESRGANx4plus_400k
+model_type: RealESRGANModel
+scale: 4
+num_gpu: auto
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 64
+ skip_connection: True
+
+# path
+path:
+ # use the pre-trained Real-ESRNet model
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
+ param_key_g: params_ema
+ strict_load_g: true
+ pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
+ param_key_d: params
+ strict_load_d: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [400000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+ # perceptual loss (content and style losses)
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ # before relu
+ 'conv1_2': 0.1
+ 'conv2_2': 0.1
+ 'conv3_4': 1
+ 'conv4_4': 1
+ 'conv5_4': 1
+ vgg_type: vgg19
+ use_input_norm: true
+ perceptual_weight: !!float 1.0
+ style_weight: 0
+ range_norm: false
+ criterion: l1
+ # gan loss
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 1e-1
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/options/finetune_realesrgan_x4plus_pairdata.yml b/RealESRGANv030/options/finetune_realesrgan_x4plus_pairdata.yml
new file mode 100644
index 0000000000000000000000000000000000000000..db45d4d275facc1191caa87d2d8618c30624477a
--- /dev/null
+++ b/RealESRGANv030/options/finetune_realesrgan_x4plus_pairdata.yml
@@ -0,0 +1,150 @@
+# general settings
+name: finetune_RealESRGANx4plus_400k_pairdata
+model_type: RealESRGANModel
+scale: 4
+num_gpu: auto
+manual_seed: 0
+
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+high_order_degradation: False # do not use the high-order degradation generation process
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: RealESRGANPairedDataset
+ dataroot_gt: datasets/DF2K
+ dataroot_lq: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
+ io_backend:
+ type: disk
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 64
+ skip_connection: True
+
+# path
+path:
+ # use the pre-trained Real-ESRNet model
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
+ param_key_g: params_ema
+ strict_load_g: true
+ pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
+ param_key_d: params
+ strict_load_d: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [400000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+ # perceptual loss (content and style losses)
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ # before relu
+ 'conv1_2': 0.1
+ 'conv2_2': 0.1
+ 'conv3_4': 1
+ 'conv4_4': 1
+ 'conv5_4': 1
+ vgg_type: vgg19
+ use_input_norm: true
+ perceptual_weight: !!float 1.0
+ style_weight: 0
+ range_norm: false
+ criterion: l1
+ # gan loss
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 1e-1
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/options/train_realesrgan_x2plus.yml b/RealESRGANv030/options/train_realesrgan_x2plus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3c98a0f370def397bdf47ede0fa5f6dd6a4411d5
--- /dev/null
+++ b/RealESRGANv030/options/train_realesrgan_x2plus.yml
@@ -0,0 +1,186 @@
+# general settings
+name: train_RealESRGANx2plus_400k_B12G4
+model_type: RealESRGANModel
+scale: 2
+num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+ scale: 2
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 64
+ skip_connection: True
+
+# path
+path:
+ # use the pre-trained Real-ESRNet model
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
+ param_key_g: params_ema
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [400000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+ # perceptual loss (content and style losses)
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ # before relu
+ 'conv1_2': 0.1
+ 'conv2_2': 0.1
+ 'conv3_4': 1
+ 'conv4_4': 1
+ 'conv5_4': 1
+ vgg_type: vgg19
+ use_input_norm: true
+ perceptual_weight: !!float 1.0
+ style_weight: 0
+ range_norm: false
+ criterion: l1
+ # gan loss
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 1e-1
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/options/train_realesrgan_x4plus.yml b/RealESRGANv030/options/train_realesrgan_x4plus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..763199a35fa0135713b4a87b00c25f63062ac8aa
--- /dev/null
+++ b/RealESRGANv030/options/train_realesrgan_x4plus.yml
@@ -0,0 +1,185 @@
+# general settings
+name: train_RealESRGANx4plus_400k_B12G4
+model_type: RealESRGANModel
+scale: 4
+num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 64
+ skip_connection: True
+
+# path
+path:
+ # use the pre-trained Real-ESRNet model
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
+ param_key_g: params_ema
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [400000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+ # perceptual loss (content and style losses)
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ # before relu
+ 'conv1_2': 0.1
+ 'conv2_2': 0.1
+ 'conv3_4': 1
+ 'conv4_4': 1
+ 'conv5_4': 1
+ vgg_type: vgg19
+ use_input_norm: true
+ perceptual_weight: !!float 1.0
+ style_weight: 0
+ range_norm: false
+ criterion: l1
+ # gan loss
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 1e-1
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/options/train_realesrnet_x2plus.yml b/RealESRGANv030/options/train_realesrnet_x2plus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..81ee9ef16817eaf17cf993cea1a4a8d51815d96c
--- /dev/null
+++ b/RealESRGANv030/options/train_realesrnet_x2plus.yml
@@ -0,0 +1,145 @@
+# general settings
+name: train_RealESRNetx2plus_1000k_B12G4
+model_type: RealESRNetModel
+scale: 2
+num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
+gt_usm: True # USM the ground-truth
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+ scale: 2
+
+# path
+path:
+ pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
+ param_key_g: params_ema
+ strict_load_g: False
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [1000000]
+ gamma: 0.5
+
+ total_iter: 1000000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/options/train_realesrnet_x4plus.yml b/RealESRGANv030/options/train_realesrnet_x4plus.yml
new file mode 100644
index 0000000000000000000000000000000000000000..45670ed824ae0c697a395049b089e50364292dfc
--- /dev/null
+++ b/RealESRGANv030/options/train_realesrnet_x4plus.yml
@@ -0,0 +1,144 @@
+# general settings
+name: train_RealESRNetx4plus_1000k_B12G4
+model_type: RealESRNetModel
+scale: 4
+num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
+gt_usm: True # USM the ground-truth
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DF2K+OST
+ type: RealESRGANDataset
+ dataroot_gt: datasets/DF2K
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 3]
+ betag_range: [0.5, 4]
+ betap_range: [1, 2]
+
+ blur_kernel_size2: 21
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.5]
+ betag_range2: [0.5, 4]
+ betap_range2: [1, 2]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 256
+ use_hflip: True
+ use_rot: False
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 5
+ batch_size_per_gpu: 12
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ # Uncomment these for validation
+ # val:
+ # name: validation
+ # type: PairedImageDataset
+ # dataroot_gt: path_to_gt
+ # dataroot_lq: path_to_lq
+ # io_backend:
+ # type: disk
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 64
+ num_block: 23
+ num_grow_ch: 32
+
+# path
+path:
+ pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
+ param_key_g: params_ema
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [1000000]
+ gamma: 0.5
+
+ total_iter: 1000000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# Uncomment these for validation
+# validation settings
+# val:
+# val_freq: !!float 5e3
+# save_img: True
+
+# metrics:
+# psnr: # metric name
+# type: calculate_psnr
+# crop_border: 4
+# test_y_channel: false
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
diff --git a/RealESRGANv030/realesrgan/__init__.py b/RealESRGANv030/realesrgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2276f1eecded80d1f00ff97b45c66c7a8922b987
--- /dev/null
+++ b/RealESRGANv030/realesrgan/__init__.py
@@ -0,0 +1,6 @@
+# flake8: noqa
+from .archs import *
+from .data import *
+from .models import *
+from .utils import *
+from .version import *
diff --git a/RealESRGANv030/realesrgan/archs/__init__.py b/RealESRGANv030/realesrgan/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3fbbf3b78e33b61fd4c33a564a9a617010d90de
--- /dev/null
+++ b/RealESRGANv030/realesrgan/archs/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import arch modules for registry
+# scan all the files that end with '_arch.py' under the archs folder
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
diff --git a/RealESRGANv030/realesrgan/archs/discriminator_arch.py b/RealESRGANv030/realesrgan/archs/discriminator_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b66ab1226d6793de846bc9828bbe427031a0e2d
--- /dev/null
+++ b/RealESRGANv030/realesrgan/archs/discriminator_arch.py
@@ -0,0 +1,67 @@
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+
+
+@ARCH_REGISTRY.register()
+class UNetDiscriminatorSN(nn.Module):
+ """Defines a U-Net discriminator with spectral normalization (SN)
+
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ Arg:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_feat (int): Channel number of base intermediate features. Default: 64.
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
+ """
+
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+ super(UNetDiscriminatorSN, self).__init__()
+ self.skip_connection = skip_connection
+ norm = spectral_norm
+ # the first convolution
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+ # downsample
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+ # upsample
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+ # extra convolutions
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+ def forward(self, x):
+ # downsample
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+ # upsample
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x4 = x4 + x2
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x5 = x5 + x1
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x6 = x6 + x0
+
+ # extra convolutions
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+ out = self.conv9(out)
+
+ return out
diff --git a/RealESRGANv030/realesrgan/archs/srvgg_arch.py b/RealESRGANv030/realesrgan/archs/srvgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..39460965c9c5ee9cd6eb41c50d33574cb8ba6e50
--- /dev/null
+++ b/RealESRGANv030/realesrgan/archs/srvgg_arch.py
@@ -0,0 +1,69 @@
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@ARCH_REGISTRY.register()
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
+ conducted on the HR feature space.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
+ upscale (int): Upsampling factor. Default: 4.
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
diff --git a/RealESRGANv030/realesrgan/data/__init__.py b/RealESRGANv030/realesrgan/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f8fdd1aa47c12de9687c578094303eb7369246
--- /dev/null
+++ b/RealESRGANv030/realesrgan/data/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import dataset modules for registry
+# scan all the files that end with '_dataset.py' under the data folder
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
diff --git a/RealESRGANv030/realesrgan/data/realesrgan_dataset.py b/RealESRGANv030/realesrgan/data/realesrgan_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf2d9e6583a6789b771679734ce55bb8a22e628
--- /dev/null
+++ b/RealESRGANv030/realesrgan/data/realesrgan_dataset.py
@@ -0,0 +1,192 @@
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from torch.utils import data as data
+
+
+@DATASET_REGISTRY.register()
+class RealESRGANDataset(data.Dataset):
+ """Dataset used for Real-ESRGAN model:
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It loads gt (Ground-Truth) images, and augments them.
+ It also generates blur kernels and sinc kernels for generating low-quality images.
+ Note that the low-quality images are processed in tensors on GPUS for faster processing.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ Please see more options in the codes.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.gt_folder = opt['dataroot_gt']
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['gt']
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ # disk backend with meta_info
+ # Each line in the meta_info describes the relative path to an image
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip().split(' ')[0] for line in fin]
+ self.paths = [os.path.join(self.gt_folder, v) for v in paths]
+
+ # blur settings for the first degradation
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
+ self.blur_sigma = opt['blur_sigma']
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
+
+ # blur settings for the second degradation
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
+ self.kernel_list2 = opt['kernel_list2']
+ self.kernel_prob2 = opt['kernel_prob2']
+ self.blur_sigma2 = opt['blur_sigma2']
+ self.betag_range2 = opt['betag_range2']
+ self.betap_range2 = opt['betap_range2']
+ self.sinc_prob2 = opt['sinc_prob2']
+
+ # a final sinc filter
+ self.final_sinc_prob = opt['final_sinc_prob']
+
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
+ # TODO: kernel range is now hard-coded, should be in the configure file
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
+ self.pulse_tensor[10, 10] = 1
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # -------------------------------- Load gt images -------------------------------- #
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+ gt_path = self.paths[index]
+ # avoid errors caused by high latency in reading files
+ retry = 3
+ while retry > 0:
+ try:
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ except (IOError, OSError) as e:
+ logger = get_root_logger()
+ logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
+ # change another file to read
+ index = random.randint(0, self.__len__())
+ gt_path = self.paths[index]
+ time.sleep(1) # sleep 1s for occasional server congestion
+ else:
+ break
+ finally:
+ retry -= 1
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+ # crop or pad to 400
+ # TODO: 400 is hard-coded. You may change it accordingly
+ h, w = img_gt.shape[0:2]
+ crop_pad_size = 400
+ # pad
+ if h < crop_pad_size or w < crop_pad_size:
+ pad_h = max(0, crop_pad_size - h)
+ pad_w = max(0, crop_pad_size - w)
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+ # crop
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+ h, w = img_gt.shape[0:2]
+ # randomly choose top and left coordinates
+ top = random.randint(0, h - crop_pad_size)
+ left = random.randint(0, w - crop_pad_size)
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob']:
+ # this sinc filter setting is for kernels ranging from [7, 21]
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel = random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ self.betag_range,
+ self.betap_range,
+ noise_range=None)
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob2']:
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel2 = random_mixed_kernels(
+ self.kernel_list2,
+ self.kernel_prob2,
+ kernel_size,
+ self.blur_sigma2,
+ self.blur_sigma2, [-math.pi, math.pi],
+ self.betag_range2,
+ self.betap_range2,
+ noise_range=None)
+
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
+ if np.random.uniform() < self.opt['final_sinc_prob']:
+ kernel_size = random.choice(self.kernel_range)
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
+ else:
+ sinc_kernel = self.pulse_tensor
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+ kernel = torch.FloatTensor(kernel)
+ kernel2 = torch.FloatTensor(kernel2)
+
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
+ return return_d
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/RealESRGANv030/realesrgan/data/realesrgan_paired_dataset.py b/RealESRGANv030/realesrgan/data/realesrgan_paired_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..386c8d72496245dae8df033c2ebbd76b41ff45f1
--- /dev/null
+++ b/RealESRGANv030/realesrgan/data/realesrgan_paired_dataset.py
@@ -0,0 +1,108 @@
+import os
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+
+@DATASET_REGISTRY.register()
+class RealESRGANPairedDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+ There are three modes:
+ 1. 'lmdb': Use lmdb files.
+ If opt['io_backend'] == lmdb.
+ 2. 'meta_info': Use meta information file to generate paths.
+ If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
+ 3. 'folder': Scan folders to generate paths.
+ The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+ Default: '{}'.
+ gt_size (int): Cropped patched size for gt patches.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (bool): Scale, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANPairedDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ # mean and std for normalizing the input images
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
+ # disk backend with meta_info
+ # Each line in the meta_info describes the relative path to an image
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip() for line in fin]
+ self.paths = []
+ for path in paths:
+ gt_path, lq_path = path.split(', ')
+ gt_path = os.path.join(self.gt_folder, gt_path)
+ lq_path = os.path.join(self.lq_folder, lq_path)
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
+ else:
+ # disk backend
+ # it will scan the whole folder to get meta info
+ # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/RealESRGANv030/realesrgan/models/__init__.py b/RealESRGANv030/realesrgan/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be7105dc75d150c49976396724085f678dc0675
--- /dev/null
+++ b/RealESRGANv030/realesrgan/models/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import model modules for registry
+# scan all the files that end with '_model.py' under the model folder
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
diff --git a/RealESRGANv030/realesrgan/models/realesrgan_model.py b/RealESRGANv030/realesrgan/models/realesrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c298a09c42433177f90001a0a31d029576072ccd
--- /dev/null
+++ b/RealESRGANv030/realesrgan/models/realesrgan_model.py
@@ -0,0 +1,258 @@
+import numpy as np
+import random
+import torch
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.models.srgan_model import SRGANModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+from collections import OrderedDict
+from torch.nn import functional as F
+
+
+@MODEL_REGISTRY.register()
+class RealESRGANModel(SRGANModel):
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It mainly performs:
+ 1. randomly synthesize LQ images in GPU tensors
+ 2. optimize the networks with GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+ to increase the degradation diversity in a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt_usm, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
+ self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
+ self.gt_usm = self.usm_sharpener(self.gt)
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
+
+ def optimize_parameters(self, current_iter):
+ # usm sharpening
+ l1_gt = self.gt_usm
+ percep_gt = self.gt_usm
+ gan_gt = self.gt_usm
+ if self.opt['l1_gt_usm'] is False:
+ l1_gt = self.gt
+ if self.opt['percep_gt_usm'] is False:
+ percep_gt = self.gt
+ if self.opt['gan_gt_usm'] is False:
+ gan_gt = self.gt
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, l1_gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(gan_gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
diff --git a/RealESRGANv030/realesrgan/models/realesrnet_model.py b/RealESRGANv030/realesrgan/models/realesrnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11668f3712bffcd062c57db14d22ca3a0e1e59d
--- /dev/null
+++ b/RealESRGANv030/realesrgan/models/realesrnet_model.py
@@ -0,0 +1,188 @@
+import numpy as np
+import random
+import torch
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.models.sr_model import SRModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+from torch.nn import functional as F
+
+
+@MODEL_REGISTRY.register()
+class RealESRNetModel(SRModel):
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It is trained without GAN losses.
+ It mainly performs:
+ 1. randomly synthesize LQ images in GPU tensors
+ 2. optimize the networks with GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRNetModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+ to increase the degradation diversity in a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ # USM sharpen the GT images
+ if self.opt['gt_usm'] is True:
+ self.gt = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
diff --git a/RealESRGANv030/realesrgan/train.py b/RealESRGANv030/realesrgan/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a9cec9ed80d9f362984779548dcec921a636a04
--- /dev/null
+++ b/RealESRGANv030/realesrgan/train.py
@@ -0,0 +1,11 @@
+# flake8: noqa
+import os.path as osp
+from basicsr.train import train_pipeline
+
+import realesrgan.archs
+import realesrgan.data
+import realesrgan.models
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/RealESRGANv030/realesrgan/utils.py b/RealESRGANv030/realesrgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e5232d61e93807f22b052499a733cd348a61a0
--- /dev/null
+++ b/RealESRGANv030/realesrgan/utils.py
@@ -0,0 +1,313 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
+ model (nn.Module): The defined network. Default: None.
+ tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
+ input images into tiles, and then process each of them. Finally, they will be merged into one image.
+ 0 denotes for do not use tile. Default: 0.
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+ half (float): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ model_path,
+ dni_weight=None,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = torch.device(
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ else:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+
+ if isinstance(model_path, list):
+ # dni
+ assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
+ loadnet = self.dni(model_path[0], model_path[1], dni_weight)
+ else:
+ # if the model_path starts with https, it will first download models to the folder: weights
+ if model_path.startswith('https://'):
+ model_path = load_file_from_url(
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+
+ # prefer to use params_ema
+ if 'params_ema' in loadnet:
+ keyname = 'params_ema'
+ else:
+ keyname = 'params'
+ model.load_state_dict(loadnet[keyname], strict=True)
+
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
+ """Deep network interpolation.
+
+ ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
+ """
+ net_a = torch.load(net_a, map_location=torch.device(loc))
+ net_b = torch.load(net_b, map_location=torch.device(loc))
+ for k, v_a in net_a[key].items():
+ net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
+ return net_a
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img = self.post_process()
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+ img_list (list[str]): A image list of image paths to be read.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == 'quit':
+ break
+
+ output = msg['output']
+ save_path = msg['save_path']
+ cv2.imwrite(save_path, output)
+ print(f'IO worker {self.qid} is done.')
diff --git a/RealESRGANv030/realesrgan/version.py b/RealESRGANv030/realesrgan/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..97bb360a8d87975130b9d864bd1746b1d8ac4e04
--- /dev/null
+++ b/RealESRGANv030/realesrgan/version.py
@@ -0,0 +1,5 @@
+# GENERATED VERSION FILE
+# TIME: Mon Jan 9 14:59:42 2023
+__version__ = '0.3.0'
+__gitsha__ = 'unknown'
+version_info = (0, 3, 0)
diff --git a/RealESRGANv030/requirements.txt b/RealESRGANv030/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0c8f3f0e75ea0174a4055bdce8255c541187e4b1
--- /dev/null
+++ b/RealESRGANv030/requirements.txt
@@ -0,0 +1,9 @@
+basicsr>=1.4.2
+facexlib>=0.2.5
+gfpgan>=1.3.5
+numpy
+opencv-python
+Pillow
+torch>=1.7
+torchvision
+tqdm
diff --git a/RealESRGANv030/scripts/extract_subimages.py b/RealESRGANv030/scripts/extract_subimages.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b969ae0d4adff403f2ad362b9afaaaee58e2cef
--- /dev/null
+++ b/RealESRGANv030/scripts/extract_subimages.py
@@ -0,0 +1,135 @@
+import argparse
+import cv2
+import numpy as np
+import os
+import sys
+from basicsr.utils import scandir
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def main(args):
+ """A multi-thread tool to crop large images to sub-images for faster IO.
+
+ opt (dict): Configuration dict. It contains:
+ n_thread (int): Thread number.
+ compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
+ and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
+ input_folder (str): Path to the input folder.
+ save_folder (str): Path to save folder.
+ crop_size (int): Crop size.
+ step (int): Step for overlapped sliding window.
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
+
+ Usage:
+ For each folder, run this script.
+ Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
+ After process, each sub_folder should have the same number of subimages.
+ Remember to modify opt configurations according to your settings.
+ """
+
+ opt = {}
+ opt['n_thread'] = args.n_thread
+ opt['compression_level'] = args.compression_level
+ opt['input_folder'] = args.input
+ opt['save_folder'] = args.output
+ opt['crop_size'] = args.crop_size
+ opt['step'] = args.step
+ opt['thresh_size'] = args.thresh_size
+ extract_subimages(opt)
+
+
+def extract_subimages(opt):
+ """Crop images to subimages.
+
+ Args:
+ opt (dict): Configuration dict. It contains:
+ input_folder (str): Path to the input folder.
+ save_folder (str): Path to save folder.
+ n_thread (int): Thread number.
+ """
+ input_folder = opt['input_folder']
+ save_folder = opt['save_folder']
+ if not osp.exists(save_folder):
+ os.makedirs(save_folder)
+ print(f'mkdir {save_folder} ...')
+ else:
+ print(f'Folder {save_folder} already exists. Exit.')
+ sys.exit(1)
+
+ # scan all images
+ img_list = list(scandir(input_folder, full_path=True))
+
+ pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
+ pool = Pool(opt['n_thread'])
+ for path in img_list:
+ pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
+ pool.close()
+ pool.join()
+ pbar.close()
+ print('All processes done.')
+
+
+def worker(path, opt):
+ """Worker for each process.
+
+ Args:
+ path (str): Image path.
+ opt (dict): Configuration dict. It contains:
+ crop_size (int): Crop size.
+ step (int): Step for overlapped sliding window.
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
+ save_folder (str): Path to save folder.
+ compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
+
+ Returns:
+ process_info (str): Process information displayed in progress bar.
+ """
+ crop_size = opt['crop_size']
+ step = opt['step']
+ thresh_size = opt['thresh_size']
+ img_name, extension = osp.splitext(osp.basename(path))
+
+ # remove the x2, x3, x4 and x8 in the filename for DIV2K
+ img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+ h, w = img.shape[0:2]
+ h_space = np.arange(0, h - crop_size + 1, step)
+ if h - (h_space[-1] + crop_size) > thresh_size:
+ h_space = np.append(h_space, h - crop_size)
+ w_space = np.arange(0, w - crop_size + 1, step)
+ if w - (w_space[-1] + crop_size) > thresh_size:
+ w_space = np.append(w_space, w - crop_size)
+
+ index = 0
+ for x in h_space:
+ for y in w_space:
+ index += 1
+ cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
+ cropped_img = np.ascontiguousarray(cropped_img)
+ cv2.imwrite(
+ osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
+ [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
+ process_info = f'Processing {img_name} ...'
+ return process_info
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
+ parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
+ parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
+ parser.add_argument(
+ '--thresh_size',
+ type=int,
+ default=0,
+ help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
+ parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
+ parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
+ args = parser.parse_args()
+
+ main(args)
diff --git a/RealESRGANv030/scripts/generate_meta_info.py b/RealESRGANv030/scripts/generate_meta_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3b7a37e85f534075c50e6c33d7cca999d8b836
--- /dev/null
+++ b/RealESRGANv030/scripts/generate_meta_info.py
@@ -0,0 +1,58 @@
+import argparse
+import cv2
+import glob
+import os
+
+
+def main(args):
+ txt_file = open(args.meta_info, 'w')
+ for folder, root in zip(args.input, args.root):
+ img_paths = sorted(glob.glob(os.path.join(folder, '*')))
+ for img_path in img_paths:
+ status = True
+ if args.check:
+ # read the image once for check, as some images may have errors
+ try:
+ img = cv2.imread(img_path)
+ except (IOError, OSError) as error:
+ print(f'Read {img_path} error: {error}')
+ status = False
+ if img is None:
+ status = False
+ print(f'Img is None: {img_path}')
+ if status:
+ # get the relative path
+ img_name = os.path.relpath(img_path, root)
+ print(img_name)
+ txt_file.write(f'{img_name}\n')
+
+
+if __name__ == '__main__':
+ """Generate meta info (txt file) for only Ground-Truth images.
+
+ It can also generate meta info from several folders into one txt file.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input',
+ nargs='+',
+ default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
+ help='Input folder, can be a list')
+ parser.add_argument(
+ '--root',
+ nargs='+',
+ default=['datasets/DF2K', 'datasets/DF2K'],
+ help='Folder root, should have the length as input folders')
+ parser.add_argument(
+ '--meta_info',
+ type=str,
+ default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
+ help='txt path for meta info')
+ parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok')
+ args = parser.parse_args()
+
+ assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
+ f'{len(args.input)} and {len(args.root)}.')
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
+
+ main(args)
diff --git a/RealESRGANv030/scripts/generate_meta_info_pairdata.py b/RealESRGANv030/scripts/generate_meta_info_pairdata.py
new file mode 100644
index 0000000000000000000000000000000000000000..76dce7e41c803a8055f3627cccb98deb51419b09
--- /dev/null
+++ b/RealESRGANv030/scripts/generate_meta_info_pairdata.py
@@ -0,0 +1,49 @@
+import argparse
+import glob
+import os
+
+
+def main(args):
+ txt_file = open(args.meta_info, 'w')
+ # sca images
+ img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
+ img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
+
+ assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
+ f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
+
+ for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
+ # get the relative paths
+ img_name_gt = os.path.relpath(img_path_gt, args.root[0])
+ img_name_lq = os.path.relpath(img_path_lq, args.root[1])
+ print(f'{img_name_gt}, {img_name_lq}')
+ txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
+
+
+if __name__ == '__main__':
+ """This script is used to generate meta info (txt file) for paired images.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input',
+ nargs='+',
+ default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
+ help='Input folder, should be [gt_folder, lq_folder]')
+ parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ')
+ parser.add_argument(
+ '--meta_info',
+ type=str,
+ default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
+ help='txt path for meta info')
+ args = parser.parse_args()
+
+ assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
+ assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
+ for i in range(2):
+ if args.input[i].endswith('/'):
+ args.input[i] = args.input[i][:-1]
+ if args.root[i] is None:
+ args.root[i] = os.path.dirname(args.input[i])
+
+ main(args)
diff --git a/RealESRGANv030/scripts/generate_multiscale_DF2K.py b/RealESRGANv030/scripts/generate_multiscale_DF2K.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4f5d8324b1624e4cb6163754703b8dac2d188fd
--- /dev/null
+++ b/RealESRGANv030/scripts/generate_multiscale_DF2K.py
@@ -0,0 +1,48 @@
+import argparse
+import glob
+import os
+from PIL import Image
+
+
+def main(args):
+ # For DF2K, we consider the following three scales,
+ # and the smallest image whose shortest edge is 400
+ scale_list = [0.75, 0.5, 1 / 3]
+ shortest_edge = 400
+
+ path_list = sorted(glob.glob(os.path.join(args.input, '*')))
+ for path in path_list:
+ print(path)
+ basename = os.path.splitext(os.path.basename(path))[0]
+
+ img = Image.open(path)
+ width, height = img.size
+ for idx, scale in enumerate(scale_list):
+ print(f'\t{scale:.2f}')
+ rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
+ rlt.save(os.path.join(args.output, f'{basename}T{idx}.png'))
+
+ # save the smallest image which the shortest edge is 400
+ if width < height:
+ ratio = height / width
+ width = shortest_edge
+ height = int(width * ratio)
+ else:
+ ratio = width / height
+ height = shortest_edge
+ width = int(height * ratio)
+ rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS)
+ rlt.save(os.path.join(args.output, f'{basename}T{idx+1}.png'))
+
+
+if __name__ == '__main__':
+ """Generate multi-scale versions for GT images with LANCZOS resampling.
+ It is now used for DF2K dataset (DIV2K + Flickr 2K)
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
+ args = parser.parse_args()
+
+ os.makedirs(args.output, exist_ok=True)
+ main(args)
diff --git a/RealESRGANv030/scripts/pytorch2onnx.py b/RealESRGANv030/scripts/pytorch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d99b2e0171265e70e7507ed8e882b616b449a1
--- /dev/null
+++ b/RealESRGANv030/scripts/pytorch2onnx.py
@@ -0,0 +1,36 @@
+import argparse
+import torch
+import torch.onnx
+from basicsr.archs.rrdbnet_arch import RRDBNet
+
+
+def main(args):
+ # An instance of the model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ if args.params:
+ keyname = 'params'
+ else:
+ keyname = 'params_ema'
+ model.load_state_dict(torch.load(args.input)[keyname])
+ # set the train mode to false since we will only run the forward pass.
+ model.train(False)
+ model.cpu().eval()
+
+ # An example input
+ x = torch.rand(1, 3, 64, 64)
+ # Export the model
+ with torch.no_grad():
+ torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
+ print(torch_out.shape)
+
+
+if __name__ == '__main__':
+ """Convert pytorch model to onnx models"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
+ parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
+ parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
+ args = parser.parse_args()
+
+ main(args)
diff --git a/RealESRGANv030/setup.cfg b/RealESRGANv030/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..9cecd96943e729db110b1960295d9d4bf76c1754
--- /dev/null
+++ b/RealESRGANv030/setup.cfg
@@ -0,0 +1,33 @@
+[flake8]
+ignore =
+ # line break before binary operator (W503)
+ W503,
+ # line break after binary operator (W504)
+ W504,
+max-line-length=120
+
+[yapf]
+based_on_style = pep8
+column_limit = 120
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+
+[isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = pkg_resources,setuptools
+known_first_party = realesrgan
+known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
+
+[codespell]
+skip = .git,./docs/build
+count =
+quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
diff --git a/RealESRGANv030/setup.py b/RealESRGANv030/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b92e31d2db1aba50767f4f844540cfd53c609d
--- /dev/null
+++ b/RealESRGANv030/setup.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+import os
+import subprocess
+import time
+
+version_file = 'realesrgan/version.py'
+
+
+def readme():
+ with open('README.md', encoding='utf-8') as f:
+ content = f.read()
+ return content
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ except OSError:
+ sha = 'unknown'
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists('.git'):
+ sha = get_git_hash()[:7]
+ else:
+ sha = 'unknown'
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open('VERSION', 'r') as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+ with open(version_file, 'w') as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+def get_requirements(filename='requirements.txt'):
+ here = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(here, filename), 'r') as f:
+ requires = [line.replace('\n', '') for line in f.readlines()]
+ return requires
+
+
+if __name__ == '__main__':
+ write_version_py()
+ setup(
+ name='realesrgan',
+ version=get_version(),
+ description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ author='Xintao Wang',
+ author_email='xintao.wang@outlook.com',
+ keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
+ url='https://github.com/xinntao/Real-ESRGAN',
+ include_package_data=True,
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ ],
+ license='BSD-3-Clause License',
+ setup_requires=['cython', 'numpy'],
+ install_requires=get_requirements(),
+ zip_safe=False)
diff --git a/RealESRGANv030/tests/data/gt.lmdb/data.mdb b/RealESRGANv030/tests/data/gt.lmdb/data.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..f28ad48dd320c1b624cdd30f492cd8fd0c1c9fab
Binary files /dev/null and b/RealESRGANv030/tests/data/gt.lmdb/data.mdb differ
diff --git a/RealESRGANv030/tests/data/gt.lmdb/lock.mdb b/RealESRGANv030/tests/data/gt.lmdb/lock.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..37b3f72fa44829db318abca1f9495d73d7d6e071
Binary files /dev/null and b/RealESRGANv030/tests/data/gt.lmdb/lock.mdb differ
diff --git a/RealESRGANv030/tests/data/gt.lmdb/meta_info.txt b/RealESRGANv030/tests/data/gt.lmdb/meta_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f42295426c1783261024a409005ee693c798951f
--- /dev/null
+++ b/RealESRGANv030/tests/data/gt.lmdb/meta_info.txt
@@ -0,0 +1,2 @@
+baboon.png (480,500,3) 1
+comic.png (360,240,3) 1
diff --git a/RealESRGANv030/tests/data/gt/baboon.png b/RealESRGANv030/tests/data/gt/baboon.png
new file mode 100644
index 0000000000000000000000000000000000000000..c81e18de0346d8801f44495f267148919f6ac70a
Binary files /dev/null and b/RealESRGANv030/tests/data/gt/baboon.png differ
diff --git a/RealESRGANv030/tests/data/gt/comic.png b/RealESRGANv030/tests/data/gt/comic.png
new file mode 100644
index 0000000000000000000000000000000000000000..600f5486503b53b77323a7c28b53822b23d576ba
Binary files /dev/null and b/RealESRGANv030/tests/data/gt/comic.png differ
diff --git a/RealESRGANv030/tests/data/lq.lmdb/data.mdb b/RealESRGANv030/tests/data/lq.lmdb/data.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..c0162153452f63afbc798e99bfcdc1a6866caa0a
Binary files /dev/null and b/RealESRGANv030/tests/data/lq.lmdb/data.mdb differ
diff --git a/RealESRGANv030/tests/data/lq.lmdb/lock.mdb b/RealESRGANv030/tests/data/lq.lmdb/lock.mdb
new file mode 100644
index 0000000000000000000000000000000000000000..c3b69ed59644c8337389f82010234aab8f688b09
Binary files /dev/null and b/RealESRGANv030/tests/data/lq.lmdb/lock.mdb differ
diff --git a/RealESRGANv030/tests/data/lq.lmdb/meta_info.txt b/RealESRGANv030/tests/data/lq.lmdb/meta_info.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6dfca0d9de4717a97db69167f020f34d8da6c0d0
--- /dev/null
+++ b/RealESRGANv030/tests/data/lq.lmdb/meta_info.txt
@@ -0,0 +1,2 @@
+baboon.png (120,125,3) 1
+comic.png (80,60,3) 1
diff --git a/RealESRGANv030/tests/data/lq/baboon.png b/RealESRGANv030/tests/data/lq/baboon.png
new file mode 100644
index 0000000000000000000000000000000000000000..bbd201245f3bb1736bc35820eb28f0d59eef766f
Binary files /dev/null and b/RealESRGANv030/tests/data/lq/baboon.png differ
diff --git a/RealESRGANv030/tests/data/lq/comic.png b/RealESRGANv030/tests/data/lq/comic.png
new file mode 100644
index 0000000000000000000000000000000000000000..c4e38ab76ecb80deb84fdc8f16f5afa009d95ddd
Binary files /dev/null and b/RealESRGANv030/tests/data/lq/comic.png differ
diff --git a/RealESRGANv030/tests/data/meta_info_gt.txt b/RealESRGANv030/tests/data/meta_info_gt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2234632d9ed7db237273779fe7cd6ddcbee4e67f
--- /dev/null
+++ b/RealESRGANv030/tests/data/meta_info_gt.txt
@@ -0,0 +1,2 @@
+baboon.png
+comic.png
diff --git a/RealESRGANv030/tests/data/meta_info_pair.txt b/RealESRGANv030/tests/data/meta_info_pair.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4775dda818d25e1a2ebf67c98df571b26d87c912
--- /dev/null
+++ b/RealESRGANv030/tests/data/meta_info_pair.txt
@@ -0,0 +1,2 @@
+gt/baboon.png, lq/baboon.png
+gt/comic.png, lq/comic.png
diff --git a/RealESRGANv030/tests/data/test_realesrgan_dataset.yml b/RealESRGANv030/tests/data/test_realesrgan_dataset.yml
new file mode 100644
index 0000000000000000000000000000000000000000..48e6ecc338e730e74ed5a24aefb66ea5e45381e7
--- /dev/null
+++ b/RealESRGANv030/tests/data/test_realesrgan_dataset.yml
@@ -0,0 +1,28 @@
+name: Demo
+type: RealESRGANDataset
+dataroot_gt: tests/data/gt
+meta_info: tests/data/meta_info_gt.txt
+io_backend:
+ type: disk
+
+blur_kernel_size: 21
+kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+sinc_prob: 1
+blur_sigma: [0.2, 3]
+betag_range: [0.5, 4]
+betap_range: [1, 2]
+
+blur_kernel_size2: 21
+kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+sinc_prob2: 1
+blur_sigma2: [0.2, 1.5]
+betag_range2: [0.5, 4]
+betap_range2: [1, 2]
+
+final_sinc_prob: 1
+
+gt_size: 128
+use_hflip: True
+use_rot: False
diff --git a/RealESRGANv030/tests/data/test_realesrgan_model.yml b/RealESRGANv030/tests/data/test_realesrgan_model.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1cbdab23be5cf973c4bea66a85ac0dca2c1d713e
--- /dev/null
+++ b/RealESRGANv030/tests/data/test_realesrgan_model.yml
@@ -0,0 +1,115 @@
+scale: 4
+num_gpu: 1
+manual_seed: 0
+is_train: True
+dist: False
+
+# ----------------- options for synthesizing training data ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 1
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 1
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 1
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 1
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 1
+jpeg_range2: [30, 95]
+
+gt_size: 32
+queue_size: 1
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 4
+ num_block: 1
+ num_grow_ch: 2
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 2
+ skip_connection: True
+
+# path
+path:
+ pretrain_network_g: ~
+ param_key_g: params_ema
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [400000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+ # perceptual loss (content and style losses)
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ # before relu
+ 'conv1_2': 0.1
+ 'conv2_2': 0.1
+ 'conv3_4': 1
+ 'conv4_4': 1
+ 'conv5_4': 1
+ vgg_type: vgg19
+ use_input_norm: true
+ perceptual_weight: !!float 1.0
+ style_weight: 0
+ range_norm: false
+ criterion: l1
+ # gan loss
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 1e-1
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+
+# validation settings
+val:
+ val_freq: !!float 5e3
+ save_img: False
diff --git a/RealESRGANv030/tests/data/test_realesrgan_paired_dataset.yml b/RealESRGANv030/tests/data/test_realesrgan_paired_dataset.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8ea9709d214852ae8f792e3ee732edf542dc382d
--- /dev/null
+++ b/RealESRGANv030/tests/data/test_realesrgan_paired_dataset.yml
@@ -0,0 +1,13 @@
+name: Demo
+type: RealESRGANPairedDataset
+scale: 4
+dataroot_gt: tests/data
+dataroot_lq: tests/data
+meta_info: tests/data/meta_info_pair.txt
+io_backend:
+ type: disk
+
+phase: train
+gt_size: 128
+use_hflip: True
+use_rot: False
diff --git a/RealESRGANv030/tests/data/test_realesrnet_model.yml b/RealESRGANv030/tests/data/test_realesrnet_model.yml
new file mode 100644
index 0000000000000000000000000000000000000000..06ceb26f4df3cad96ea8d00cf1ede0dc85d5b8d4
--- /dev/null
+++ b/RealESRGANv030/tests/data/test_realesrnet_model.yml
@@ -0,0 +1,75 @@
+scale: 4
+num_gpu: 1
+manual_seed: 0
+is_train: True
+dist: False
+
+# ----------------- options for synthesizing training data ----------------- #
+gt_usm: True # USM the ground-truth
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 1
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 1
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 1
+resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 1
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 1
+jpeg_range2: [30, 95]
+
+gt_size: 32
+queue_size: 1
+
+# network structures
+network_g:
+ type: RRDBNet
+ num_in_ch: 3
+ num_out_ch: 3
+ num_feat: 4
+ num_block: 1
+ num_grow_ch: 2
+
+# path
+path:
+ pretrain_network_g: ~
+ param_key_g: params_ema
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [1000000]
+ gamma: 0.5
+
+ total_iter: 1000000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+
+# validation settings
+val:
+ val_freq: !!float 5e3
+ save_img: False
diff --git a/RealESRGANv030/tests/test_dataset.py b/RealESRGANv030/tests/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..715b4082645c131d43d728ae8f65bcc2430aa8c9
--- /dev/null
+++ b/RealESRGANv030/tests/test_dataset.py
@@ -0,0 +1,151 @@
+import pytest
+import yaml
+
+from realesrgan.data.realesrgan_dataset import RealESRGANDataset
+from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
+
+
+def test_realesrgan_dataset():
+
+ with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
+ opt = yaml.load(f, Loader=yaml.FullLoader)
+
+ dataset = RealESRGANDataset(opt)
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
+ assert len(dataset) == 2 # whether to read correct meta info
+ assert dataset.kernel_list == [
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
+ ] # correct initialization the degradation configurations
+ assert dataset.betag_range2 == [0.5, 4]
+
+ # test __getitem__
+ result = dataset.__getitem__(0)
+ # check returned keys
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 400, 400)
+ assert result['kernel1'].shape == (21, 21)
+ assert result['kernel2'].shape == (21, 21)
+ assert result['sinc_kernel'].shape == (21, 21)
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
+
+ # ------------------ test lmdb backend -------------------- #
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+ opt['io_backend']['type'] = 'lmdb'
+
+ dataset = RealESRGANDataset(opt)
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
+ assert len(dataset.paths) == 2 # whether to read correct meta info
+ assert dataset.kernel_list == [
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
+ ] # correct initialization the degradation configurations
+ assert dataset.betag_range2 == [0.5, 4]
+
+ # test __getitem__
+ result = dataset.__getitem__(1)
+ # check returned keys
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 400, 400)
+ assert result['kernel1'].shape == (21, 21)
+ assert result['kernel2'].shape == (21, 21)
+ assert result['sinc_kernel'].shape == (21, 21)
+ assert result['gt_path'] == 'comic'
+
+ # ------------------ test with sinc_prob = 0 -------------------- #
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+ opt['io_backend']['type'] = 'lmdb'
+ opt['sinc_prob'] = 0
+ opt['sinc_prob2'] = 0
+ opt['final_sinc_prob'] = 0
+ dataset = RealESRGANDataset(opt)
+ result = dataset.__getitem__(0)
+ # check returned keys
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 400, 400)
+ assert result['kernel1'].shape == (21, 21)
+ assert result['kernel2'].shape == (21, 21)
+ assert result['sinc_kernel'].shape == (21, 21)
+ assert result['gt_path'] == 'baboon'
+
+ # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
+ with pytest.raises(ValueError):
+ opt['dataroot_gt'] = 'tests/data/gt'
+ opt['io_backend']['type'] = 'lmdb'
+ dataset = RealESRGANDataset(opt)
+
+
+def test_realesrgan_paired_dataset():
+
+ with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
+ opt = yaml.load(f, Loader=yaml.FullLoader)
+
+ dataset = RealESRGANPairedDataset(opt)
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
+ assert len(dataset) == 2 # whether to read correct meta info
+
+ # test __getitem__
+ result = dataset.__getitem__(0)
+ # check returned keys
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 128, 128)
+ assert result['lq'].shape == (3, 32, 32)
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
+ assert result['lq_path'] == 'tests/data/lq/baboon.png'
+
+ # ------------------ test lmdb backend -------------------- #
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
+ opt['dataroot_lq'] = 'tests/data/lq.lmdb'
+ opt['io_backend']['type'] = 'lmdb'
+
+ dataset = RealESRGANPairedDataset(opt)
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
+ assert len(dataset) == 2 # whether to read correct meta info
+
+ # test __getitem__
+ result = dataset.__getitem__(1)
+ # check returned keys
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 128, 128)
+ assert result['lq'].shape == (3, 32, 32)
+ assert result['gt_path'] == 'comic'
+ assert result['lq_path'] == 'comic'
+
+ # ------------------ test paired_paths_from_folder -------------------- #
+ opt['dataroot_gt'] = 'tests/data/gt'
+ opt['dataroot_lq'] = 'tests/data/lq'
+ opt['io_backend'] = dict(type='disk')
+ opt['meta_info'] = None
+
+ dataset = RealESRGANPairedDataset(opt)
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
+ assert len(dataset) == 2 # whether to read correct meta info
+
+ # test __getitem__
+ result = dataset.__getitem__(0)
+ # check returned keys
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 128, 128)
+ assert result['lq'].shape == (3, 32, 32)
+
+ # ------------------ test normalization -------------------- #
+ dataset.mean = [0.5, 0.5, 0.5]
+ dataset.std = [0.5, 0.5, 0.5]
+ # test __getitem__
+ result = dataset.__getitem__(0)
+ # check returned keys
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
+ assert set(expected_keys).issubset(set(result.keys()))
+ # check shape and contents
+ assert result['gt'].shape == (3, 128, 128)
+ assert result['lq'].shape == (3, 32, 32)
diff --git a/RealESRGANv030/tests/test_discriminator_arch.py b/RealESRGANv030/tests/test_discriminator_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c56a40c7743630aa63b3e99bca8dc1a85949c4c5
--- /dev/null
+++ b/RealESRGANv030/tests/test_discriminator_arch.py
@@ -0,0 +1,19 @@
+import torch
+
+from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
+
+
+def test_unetdiscriminatorsn():
+ """Test arch: UNetDiscriminatorSN."""
+
+ # model init and forward (cpu)
+ net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+ output = net(img)
+ assert output.shape == (1, 1, 32, 32)
+
+ # model init and forward (gpu)
+ if torch.cuda.is_available():
+ net.cuda()
+ output = net(img.cuda())
+ assert output.shape == (1, 1, 32, 32)
diff --git a/RealESRGANv030/tests/test_model.py b/RealESRGANv030/tests/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c20bb1d56ed20222e929e9c94026f6ea383c6026
--- /dev/null
+++ b/RealESRGANv030/tests/test_model.py
@@ -0,0 +1,126 @@
+import torch
+import yaml
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.data.paired_image_dataset import PairedImageDataset
+from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
+
+from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
+from realesrgan.models.realesrgan_model import RealESRGANModel
+from realesrgan.models.realesrnet_model import RealESRNetModel
+
+
+def test_realesrnet_model():
+ with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
+ opt = yaml.load(f, Loader=yaml.FullLoader)
+
+ # build model
+ model = RealESRNetModel(opt)
+ # test attributes
+ assert model.__class__.__name__ == 'RealESRNetModel'
+ assert isinstance(model.net_g, RRDBNet)
+ assert isinstance(model.cri_pix, L1Loss)
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
+
+ # prepare data
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
+ model.feed_data(data)
+ # check dequeue
+ model.feed_data(data)
+ # check data shape
+ assert model.lq.shape == (1, 3, 8, 8)
+ assert model.gt.shape == (1, 3, 32, 32)
+
+ # change probability to test if-else
+ model.opt['gaussian_noise_prob'] = 0
+ model.opt['gray_noise_prob'] = 0
+ model.opt['second_blur_prob'] = 0
+ model.opt['gaussian_noise_prob2'] = 0
+ model.opt['gray_noise_prob2'] = 0
+ model.feed_data(data)
+ # check data shape
+ assert model.lq.shape == (1, 3, 8, 8)
+ assert model.gt.shape == (1, 3, 32, 32)
+
+ # ----------------- test nondist_validation -------------------- #
+ # construct dataloader
+ dataset_opt = dict(
+ name='Demo',
+ dataroot_gt='tests/data/gt',
+ dataroot_lq='tests/data/lq',
+ io_backend=dict(type='disk'),
+ scale=4,
+ phase='val')
+ dataset = PairedImageDataset(dataset_opt)
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ assert model.is_train is True
+ model.nondist_validation(dataloader, 1, None, False)
+ assert model.is_train is True
+
+
+def test_realesrgan_model():
+ with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
+ opt = yaml.load(f, Loader=yaml.FullLoader)
+
+ # build model
+ model = RealESRGANModel(opt)
+ # test attributes
+ assert model.__class__.__name__ == 'RealESRGANModel'
+ assert isinstance(model.net_g, RRDBNet) # generator
+ assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
+ assert isinstance(model.cri_pix, L1Loss)
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
+ assert isinstance(model.cri_gan, GANLoss)
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
+
+ # prepare data
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
+ model.feed_data(data)
+ # check dequeue
+ model.feed_data(data)
+ # check data shape
+ assert model.lq.shape == (1, 3, 8, 8)
+ assert model.gt.shape == (1, 3, 32, 32)
+
+ # change probability to test if-else
+ model.opt['gaussian_noise_prob'] = 0
+ model.opt['gray_noise_prob'] = 0
+ model.opt['second_blur_prob'] = 0
+ model.opt['gaussian_noise_prob2'] = 0
+ model.opt['gray_noise_prob2'] = 0
+ model.feed_data(data)
+ # check data shape
+ assert model.lq.shape == (1, 3, 8, 8)
+ assert model.gt.shape == (1, 3, 32, 32)
+
+ # ----------------- test nondist_validation -------------------- #
+ # construct dataloader
+ dataset_opt = dict(
+ name='Demo',
+ dataroot_gt='tests/data/gt',
+ dataroot_lq='tests/data/lq',
+ io_backend=dict(type='disk'),
+ scale=4,
+ phase='val')
+ dataset = PairedImageDataset(dataset_opt)
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ assert model.is_train is True
+ model.nondist_validation(dataloader, 1, None, False)
+ assert model.is_train is True
+
+ # ----------------- test optimize_parameters -------------------- #
+ model.feed_data(data)
+ model.optimize_parameters(1)
+ assert model.output.shape == (1, 3, 32, 32)
+ assert isinstance(model.log_dict, dict)
+ # check returned keys
+ expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
diff --git a/RealESRGANv030/tests/test_utils.py b/RealESRGANv030/tests/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7919b74905495b4b6f4aa957a1f0b5d7a174c782
--- /dev/null
+++ b/RealESRGANv030/tests/test_utils.py
@@ -0,0 +1,87 @@
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
+
+from realesrgan.utils import RealESRGANer
+
+
+def test_realesrganer():
+ # initialize with default model
+ restorer = RealESRGANer(
+ scale=4,
+ model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
+ model=None,
+ tile=10,
+ tile_pad=10,
+ pre_pad=2,
+ half=False)
+ assert isinstance(restorer.model, RRDBNet)
+ assert restorer.half is False
+ # initialize with user-defined model
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ restorer = RealESRGANer(
+ scale=4,
+ model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
+ model=model,
+ tile=10,
+ tile_pad=10,
+ pre_pad=2,
+ half=True)
+ # test attribute
+ assert isinstance(restorer.model, RRDBNet)
+ assert restorer.half is True
+
+ # ------------------ test pre_process ---------------- #
+ img = np.random.random((12, 12, 3)).astype(np.float32)
+ restorer.pre_process(img)
+ assert restorer.img.shape == (1, 3, 14, 14)
+ # with modcrop
+ restorer.scale = 1
+ restorer.pre_process(img)
+ assert restorer.img.shape == (1, 3, 16, 16)
+
+ # ------------------ test process ---------------- #
+ restorer.process()
+ assert restorer.output.shape == (1, 3, 64, 64)
+
+ # ------------------ test post_process ---------------- #
+ restorer.mod_scale = 4
+ output = restorer.post_process()
+ assert output.shape == (1, 3, 60, 60)
+
+ # ------------------ test tile_process ---------------- #
+ restorer.scale = 4
+ img = np.random.random((12, 12, 3)).astype(np.float32)
+ restorer.pre_process(img)
+ restorer.tile_process()
+ assert restorer.output.shape == (1, 3, 64, 64)
+
+ # ------------------ test enhance ---------------- #
+ img = np.random.random((12, 12, 3)).astype(np.float32)
+ result = restorer.enhance(img, outscale=2)
+ assert result[0].shape == (24, 24, 3)
+ assert result[1] == 'RGB'
+
+ # ------------------ test enhance with 16-bit image---------------- #
+ img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
+ result = restorer.enhance(img, outscale=2)
+ assert result[0].shape == (8, 8, 3)
+ assert result[1] == 'RGB'
+
+ # ------------------ test enhance with gray image---------------- #
+ img = np.random.random((4, 4)).astype(np.float32)
+ result = restorer.enhance(img, outscale=2)
+ assert result[0].shape == (8, 8)
+ assert result[1] == 'L'
+
+ # ------------------ test enhance with RGBA---------------- #
+ img = np.random.random((4, 4, 4)).astype(np.float32)
+ result = restorer.enhance(img, outscale=2)
+ assert result[0].shape == (8, 8, 4)
+ assert result[1] == 'RGBA'
+
+ # ------------------ test enhance with RGBA, alpha_upsampler---------------- #
+ restorer.tile_size = 0
+ img = np.random.random((4, 4, 4)).astype(np.float32)
+ result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
+ assert result[0].shape == (8, 8, 4)
+ assert result[1] == 'RGBA'
diff --git a/RealESRGANv030/weights/README.md b/RealESRGANv030/weights/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d7b7e642591ef88575d9e6c360a4d29e0cc1a4f
--- /dev/null
+++ b/RealESRGANv030/weights/README.md
@@ -0,0 +1,3 @@
+# Weights
+
+Put the downloaded weights to this folder.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ef6a0bc74cf1f12cf0b5b73724a6c2d84eb5eb8
--- /dev/null
+++ b/app.py
@@ -0,0 +1,366 @@
+import os
+import random
+
+import autocuda
+from pyabsa.utils.pyabsa_utils import fprint
+
+from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, \
+ DPMSolverMultistepScheduler
+import gradio as gr
+import torch
+from PIL import Image
+import utils
+import datetime
+import time
+import psutil
+
+from RealESRGANv030.interface import realEsrgan
+
+start_time = time.time()
+is_colab = utils.is_google_colab()
+
+device = autocuda.auto_cuda()
+dtype = torch.float16 if device != 'cpu' else torch.float32
+
+class Model:
+ def __init__(self, name, path="", prefix=""):
+ self.name = name
+ self.path = path
+ self.prefix = prefix
+ self.pipe_t2i = None
+ self.pipe_i2i = None
+
+
+models = [
+ Model("anything v3", "Linaqruf/anything-v3.0", "anything v3 style"),
+]
+# Model("Spider-Verse", "nitrosocke/spider-verse-diffusion", "spiderverse style "),
+# Model("Balloon Art", "Fictiverse/Stable_Diffusion_BalloonArt_Model", "BalloonArt "),
+# Model("Elden Ring", "nitrosocke/elden-ring-diffusion", "elden ring style "),
+# Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy ")
+# Model("Pokémon", "lambdalabs/sd-pokemon-diffusers", ""),
+# Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
+# Model("Robo Diffusion", "nousr/robo-diffusion", ""),
+
+scheduler = DPMSolverMultistepScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000,
+ trained_betas=None,
+ predict_epsilon=True,
+ thresholding=False,
+ algorithm_type="dpmsolver++",
+ solver_type="midpoint",
+ lower_order_final=True,
+)
+
+custom_model = None
+if is_colab:
+ models.insert(0, Model("Custom model"))
+ custom_model = models[0]
+
+last_mode = "txt2img"
+current_model = models[1] if is_colab else models[0]
+current_model_path = current_model.path
+
+if is_colab:
+ pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=dtype, scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (images, False))
+
+else: # download all models
+ print(f"{datetime.datetime.now()} Downloading vae...")
+ vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=dtype)
+ for model in models:
+ try:
+ print(f"{datetime.datetime.now()} Downloading {model.name} model...")
+ unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=dtype)
+ model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model.path, unet=unet, vae=vae,
+ torch_dtype=dtype, scheduler=scheduler,
+ safety_checker=None)
+ model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model.path, unet=unet, vae=vae,
+ torch_dtype=dtype,
+ scheduler=scheduler, safety_checker=None)
+ except Exception as e:
+ print(f"{datetime.datetime.now()} Failed to load model " + model.name + ": " + str(e))
+ models.remove(model)
+ pipe = models[0].pipe_t2i
+
+# model.pipe_i2i = torch.compile(model.pipe_i2i)
+# model.pipe_t2i = torch.compile(model.pipe_t2i)
+if torch.cuda.is_available():
+ pipe = pipe.to(device)
+
+
+# device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
+
+
+def error_str(error, title="Error"):
+ return f"""#### {title}
+ {error}""" if error else ""
+
+
+def custom_model_changed(path):
+ models[0].path = path
+ global current_model
+ current_model = models[0]
+
+
+def on_model_change(model_name):
+ prefix = "Enter prompt. \"" + next((m.prefix for m in models if m.name == model_name),
+ None) + "\" is prefixed automatically" if model_name != models[
+ 0].name else "Don't forget to use the custom model prefix in the prompt!"
+
+ return gr.update(visible=model_name == models[0].name), gr.update(placeholder=prefix)
+
+
+def inference(model_name, prompt, guidance, steps, width=512, height=512, seed=0, img=None, strength=0.5,
+ neg_prompt="", scale_factor=2, tile=200):
+ fprint(psutil.virtual_memory()) # print memory usage
+ prompt = 'detailed fingers, beautiful hands,' + prompt
+ fprint(f"Prompt: {prompt}")
+ global current_model
+ for model in models:
+ if model.name == model_name:
+ current_model = model
+ model_path = current_model.path
+
+ generator = torch.Generator(device).manual_seed(seed) if seed != 0 else None
+
+ try:
+ if img is not None:
+ return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height,
+ generator, scale_factor, tile), None
+ else:
+ return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator,
+ scale_factor, tile), None
+ except Exception as e:
+ return None, error_str(e)
+ # if img is not None:
+ # return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height,
+ # generator, scale_factor), None
+ # else:
+ # return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator, scale_factor), None
+
+
+def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator, scale_factor, tile):
+ print(f"{datetime.datetime.now()} txt_to_img, model: {current_model.name}")
+
+ global last_mode
+ global pipe
+ global current_model_path
+ if model_path != current_model_path or last_mode != "txt2img":
+ current_model_path = model_path
+
+ if is_colab or current_model == custom_model:
+ pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=dtype,
+ scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (images, False))
+ else:
+ pipe = current_model.pipe_t2i
+
+ if torch.cuda.is_available():
+ pipe = pipe.to(device)
+ last_mode = "txt2img"
+
+ prompt = current_model.prefix + prompt
+ result = pipe(
+ prompt,
+ negative_prompt=neg_prompt,
+ # num_images_per_prompt=n_images,
+ num_inference_steps=int(steps),
+ guidance_scale=guidance,
+ width=width,
+ height=height,
+ generator=generator)
+ # result.images[0] = magnifier.magnify(result.images[0], scale_factor=scale_factor)
+
+ # save image
+ img_file = "imgs/result-{}.png".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+ result.images[0].save(img_file)
+
+ # enhance resolution
+ fp32 = True if device=='cpu' else False
+ result.images[0] = realEsrgan(
+ input_dir = img_file,
+ suffix = '',
+ output_dir= "imgs",
+ fp32 = fp32,
+ outscale = scale_factor,
+ tile = tile
+ )[0]
+ print('Complete')
+
+ return replace_nsfw_images(result)
+
+
+def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator, scale_factor, tile):
+ fprint(f"{datetime.datetime.now()} img_to_img, model: {model_path}")
+
+ global last_mode
+ global pipe
+ global current_model_path
+ if model_path != current_model_path or last_mode != "img2img":
+ current_model_path = model_path
+
+ if is_colab or current_model == custom_model:
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=dtype,
+ scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (
+ images, False))
+ else:
+ # pipe = pipe.to("cpu")
+ pipe = current_model.pipe_i2i
+
+ if torch.cuda.is_available():
+ pipe = pipe.to(device)
+ last_mode = "img2img"
+
+ prompt = current_model.prefix + prompt
+ ratio = min(height / img.height, width / img.width)
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
+ result = pipe(
+ prompt,
+ negative_prompt=neg_prompt,
+ # num_images_per_prompt=n_images,
+ image=img,
+ num_inference_steps=int(steps),
+ strength=strength,
+ guidance_scale=guidance,
+ # width=width,
+ # height=height,
+ generator=generator)
+
+ # save image
+ result.images[0].save("imgs/result-{}.png".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
+
+ # enhance resolution
+ fp32 = True if device=='cpu' else False
+ result.images[0] = realEsrgan(
+ input_dir = img_file,
+ suffix = '',
+ output_dir= "imgs",
+ fp32 = fp32,
+ outscale = scale_factor,
+ tile = tile
+ )[0]
+ print('Complete')
+
+ return replace_nsfw_images(result)
+
+
+def replace_nsfw_images(results):
+ if is_colab:
+ return results.images[0]
+ if hasattr(results, "nsfw_content_detected") and results.nsfw_content_detected:
+ for i in range(len(results.images)):
+ if results.nsfw_content_detected[i]:
+ results.images[i] = Image.open("nsfw.png")
+ return results.images[0]
+
+
+css = 'style.css'
+with gr.Blocks(css=css) as demo:
+ if not os.path.exists('imgs'):
+ os.mkdir('imgs')
+
+ gr.Markdown('# RealESRGAN enhanced Anime Diffusion')
+ gr.Markdown(
+ "## Author: [dotmet](https://github.com/dotmet) Github:[Github](https://github.com/dotmet/Real-ESRGAN-Enhanced-Anime-Diffusion)")
+ gr.Markdown(
+ "### You can duplicate this demo on HuggingFace Spaces, click [here](https://huggingface.co/spaces/yangheng/Super-Resolution-Anime-Diffusion?duplicate=true)")
+
+ with gr.Row():
+ with gr.Column(scale=55):
+ with gr.Group():
+ gr.Markdown("Text to image")
+
+ model_name = gr.Dropdown(label="Model", choices=[m.name for m in models], value=current_model.name)
+
+ with gr.Box(visible=False) as custom_model_group:
+ custom_model_path = gr.Textbox(label="Custom model path",
+ placeholder="Path to model, e.g. nitrosocke/Arcane-Diffusion",
+ interactive=True)
+ gr.HTML(
+ "Custom models have to be downloaded first, so give it some time.
")
+
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=2,
+ placeholder="Enter prompt. Style applied automatically").style(container=False)
+ with gr.Row():
+ generate = gr.Button(value="Generate")
+
+ with gr.Row():
+ with gr.Group():
+ neg_prompt = gr.Textbox(label="Negative prompt", placeholder="What to exclude from the image")
+
+ image_out = gr.Image(height=512)
+ # gallery = gr.Gallery(
+ # label="Generated images", show_label=False, elem_id="gallery"
+ # ).style(grid=[1], height="auto")
+ error_output = gr.Markdown()
+
+ with gr.Column(scale=45):
+ with gr.Group():
+ gr.Markdown("Image to Image")
+
+ with gr.Row():
+ with gr.Group():
+ image = gr.Image(label="Image", height=256, tool="editor", type="pil")
+ strength = gr.Slider(label="Transformation strength", minimum=0, maximum=1, step=0.01,
+ value=0.5)
+
+ with gr.Row():
+ with gr.Group():
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
+
+ with gr.Row():
+ guidance = gr.Slider(label="Guidance scale", value=7.5, maximum=15)
+ steps = gr.Slider(label="Steps", value=15, minimum=2, maximum=75, step=1)
+
+ with gr.Row():
+ width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
+ height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
+ with gr.Row():
+ scale_factor = gr.Slider(label='Scale factor (to magnify image) (1, 2, 4, 8)',
+ value=1, minimum=1, maximum=8, step=1)
+ with gr.Row():
+ tile = gr.Slider(label='''Tile for magnify
+ (depend on the memory of your device, 0=no tile)''',
+ value=0, minimum=0, maximum=10000, step=10)
+ with gr.Row():
+ seed = gr.Slider(0, 114514, label='Random Seed (0 = random)', value=0, step=1)
+
+ if is_colab:
+ model_name.change(on_model_change, inputs=model_name, outputs=[custom_model_group, prompt], queue=False)
+ custom_model_path.change(custom_model_changed, inputs=custom_model_path, outputs=None)
+ # n_images.change(lambda n: gr.Gallery().style(grid=[2 if n > 1 else 1], height="auto"), inputs=n_images, outputs=gallery)
+
+ gr.Markdown('''### based on [Anything V3](https://huggingface.co/Linaqruf/anything-v3.0)
+ and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)
+ ''')
+
+ inputs = [model_name, prompt, guidance, steps, width, height, seed, image, strength, neg_prompt, scale_factor, tile]
+ outputs = [image_out, error_output]
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
+ generate.click(inference, inputs=inputs, outputs=outputs, api_name="generate")
+
+ prompt_keys = [
+ 'girl', 'lovely', 'cute', 'beautiful eyes', 'cumulonimbus clouds', 'detailed fingers',
+ random.choice(['dress']),
+ random.choice(['white hair']),
+ random.choice(['blue eyes']),
+ random.choice(['flower meadow']),
+ random.choice(['Elif', 'Angel'])
+ ]
+ prompt.value = ','.join(prompt_keys)
+ ex = gr.Examples([
+ [models[0].name, prompt.value, 7.5, 15],
+
+ ], inputs=[model_name, prompt, guidance, steps, seed], outputs=outputs, fn=inference, cache_examples=False)
+
+print(f"Space built in {time.time() - start_time:.2f} seconds")
+
+if not is_colab:
+ demo.queue(concurrency_count=2)
+demo.launch(debug=is_colab, enable_queue=True, share=is_colab)
diff --git a/diffusers/__init__.py b/diffusers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bda39cf3d49e4c0a4536a3439a390a5ce622eed
--- /dev/null
+++ b/diffusers/__init__.py
@@ -0,0 +1,123 @@
+from .utils import (
+ is_flax_available,
+ is_inflect_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_torch_available,
+ is_transformers_available,
+ is_unidecode_available,
+)
+
+
+__version__ = "0.10.0.dev0"
+
+from .configuration_utils import ConfigMixin
+from .onnx_utils import OnnxRuntimeModel
+from .utils import logging
+
+
+if is_torch_available():
+ from .modeling_utils import ModelMixin
+ from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+ from .optimization import (
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+ )
+ from .pipeline_utils import DiffusionPipeline
+ from .pipelines import (
+ DanceDiffusionPipeline,
+ DDIMPipeline,
+ DDPMPipeline,
+ KarrasVePipeline,
+ LDMPipeline,
+ LDMSuperResolutionPipeline,
+ PNDMPipeline,
+ RePaintPipeline,
+ ScoreSdeVePipeline,
+ )
+ from .schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ IPNDMScheduler,
+ KarrasVeScheduler,
+ KDPM2AncestralDiscreteScheduler,
+ KDPM2DiscreteScheduler,
+ PNDMScheduler,
+ RePaintScheduler,
+ SchedulerMixin,
+ ScoreSdeVeScheduler,
+ VQDiffusionScheduler,
+ )
+ from .training_utils import EMAModel
+else:
+ from .utils.dummy_pt_objects import * # noqa F403
+
+if is_torch_available() and is_scipy_available():
+ from .schedulers import LMSDiscreteScheduler
+else:
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
+
+if is_torch_available() and is_transformers_available():
+ from .pipelines import (
+ AltDiffusionImg2ImgPipeline,
+ AltDiffusionPipeline,
+ CycleDiffusionPipeline,
+ LDMTextToImagePipeline,
+ StableDiffusionImageVariationPipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionPipeline,
+ StableDiffusionPipelineSafe,
+ StableDiffusionUpscalePipeline,
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ VQDiffusionPipeline,
+ )
+else:
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
+
+if is_torch_available() and is_transformers_available() and is_onnx_available():
+ from .pipelines import (
+ OnnxStableDiffusionImg2ImgPipeline,
+ OnnxStableDiffusionInpaintPipeline,
+ OnnxStableDiffusionInpaintPipelineLegacy,
+ OnnxStableDiffusionPipeline,
+ StableDiffusionOnnxPipeline,
+ )
+else:
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
+
+if is_flax_available():
+ from .modeling_flax_utils import FlaxModelMixin
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
+ from .models.vae_flax import FlaxAutoencoderKL
+ from .pipeline_flax_utils import FlaxDiffusionPipeline
+ from .schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDDPMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxKarrasVeScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+ FlaxSchedulerMixin,
+ FlaxScoreSdeVeScheduler,
+ )
+else:
+ from .utils.dummy_flax_objects import * # noqa F403
+
+if is_flax_available() and is_transformers_available():
+ from .pipelines import FlaxStableDiffusionPipeline
+else:
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
diff --git a/diffusers/commands/__init__.py b/diffusers/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..902bd46cedc6f2df785c1dc5d2e6bd8ef7c69ca6
--- /dev/null
+++ b/diffusers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseDiffusersCLICommand(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_subcommand(parser: ArgumentParser):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def run(self):
+ raise NotImplementedError()
diff --git a/diffusers/commands/diffusers_cli.py b/diffusers/commands/diffusers_cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..30084e55ba4eeec79c87a99eae3e60a6233dc556
--- /dev/null
+++ b/diffusers/commands/diffusers_cli.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+from .env import EnvironmentCommand
+
+
+def main():
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []")
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+ # Register commands
+ EnvironmentCommand.register_subcommand(commands_parser)
+
+ # Let's go
+ args = parser.parse_args()
+
+ if not hasattr(args, "func"):
+ parser.print_help()
+ exit(1)
+
+ # Run
+ service = args.func(args)
+ service.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/commands/env.py b/diffusers/commands/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a878bff6688d3c510b53c60ac9d0e51e4aebcc
--- /dev/null
+++ b/diffusers/commands/env.py
@@ -0,0 +1,70 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from argparse import ArgumentParser
+
+import huggingface_hub
+
+from .. import __version__ as version
+from ..utils import is_torch_available, is_transformers_available
+from . import BaseDiffusersCLICommand
+
+
+def info_command_factory(_):
+ return EnvironmentCommand()
+
+
+class EnvironmentCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ download_parser = parser.add_parser("env")
+ download_parser.set_defaults(func=info_command_factory)
+
+ def run(self):
+ hub_version = huggingface_hub.__version__
+
+ pt_version = "not installed"
+ pt_cuda_available = "NA"
+ if is_torch_available():
+ import torch
+
+ pt_version = torch.__version__
+ pt_cuda_available = torch.cuda.is_available()
+
+ transformers_version = "not installed"
+ if is_transformers_available:
+ import transformers
+
+ transformers_version = transformers.__version__
+
+ info = {
+ "`diffusers` version": version,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+ "Huggingface_hub version": hub_version,
+ "Transformers version": transformers_version,
+ "Using GPU in script?": "",
+ "Using distributed or parallel set-up in script?": "",
+ }
+
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+ print(self.format_dict(info))
+
+ return info
+
+ @staticmethod
+ def format_dict(d):
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diff --git a/diffusers/configuration_utils.py b/diffusers/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf23010c3c15f0fd7608888cb22f19e0045daf4
--- /dev/null
+++ b/diffusers/configuration_utils.py
@@ -0,0 +1,613 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixin base class and utilities."""
+import dataclasses
+import functools
+import importlib
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
+ subclass).
+ """
+ config_name = None
+ ignore_for_config = []
+ has_compatibles = False
+
+ _deprecated_kwargs = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ # Special case for `kwargs` used in deprecation warning added to schedulers
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+ # or solve in a more general way.
+ kwargs.pop("kwargs", None)
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a config dictionary
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
+ configuration files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the Python class.
+ `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
+ overwrite same named arguments of `config`.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
+ # ======>
+
+ if not isinstance(config, dict):
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+ if "Scheduler" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+ " be removed in v1.0.0."
+ )
+ elif "Model" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+ " instead. This functionality will be removed in v1.0.0."
+ )
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+ # Allow dtype to be specified on initialization
+ if "dtype" in unused_kwargs:
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+ # add possible deprecated kwargs
+ for deprecated_kwarg in cls._deprecated_kwargs:
+ if deprecated_kwarg in unused_kwargs:
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**init_dict)
+
+ # make sure to also save config parameters that might be used for compatible classes
+ model.register_to_config(**hidden_dict)
+
+ # add hidden kwargs of compatible classes to unused_kwargs
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+ if return_unused_kwargs:
+ return (model, unused_kwargs)
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(cls, *args, **kwargs):
+ deprecation_message = (
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+ " removed in version v1.0.0"
+ )
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+ return cls.load_config(*args, **kwargs)
+
+ @classmethod
+ def load_config(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Instantiate a Python class from a config dictionary
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ _ = kwargs.pop("mirror", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if return_unused_kwargs:
+ return config_dict, kwargs
+
+ return config_dict
+
+ @staticmethod
+ def _get_init_keys(cls):
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ # 0. Copy origin config dict
+ original_dict = {k: v for k, v in config_dict.items()}
+
+ # 1. Retrieve expected config attributes from __init__ signature
+ expected_keys = cls._get_init_keys(cls)
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove flax internal keys
+ if hasattr(cls, "_flax_internal_args"):
+ for arg in cls._flax_internal_args:
+ expected_keys.remove(arg)
+
+ # 2. Remove attributes that cannot be expected from expected config attributes
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+
+ # load diffusers library to import compatible and original scheduler
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+ if cls.has_compatibles:
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+ else:
+ compatible_classes = []
+
+ expected_keys_comp_cls = set()
+ for c in compatible_classes:
+ expected_keys_c = cls._get_init_keys(c)
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+ # remove attributes from orig class that cannot be expected
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
+ orig_cls = getattr(diffusers_library, orig_cls_name)
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+ # remove private attributes
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
+ init_dict = {}
+ for key in expected_keys:
+ # if config param is passed to kwarg and is present in config dict
+ # it should overwrite existing config dict key
+ if key in kwargs and key in config_dict:
+ config_dict[key] = kwargs.pop(key)
+
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # 4. Give nice warning if unexpected values have been passed
+ if len(config_dict) > 0:
+ logger.warning(
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
+ "but are not expected and will be ignored. Please verify your "
+ f"{cls.config_name} configuration file."
+ )
+
+ # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.info(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ # 6. Define unused keyword arguments
+ unused_kwargs = {**config_dict, **kwargs}
+
+ # 7. Define "hidden" config parameters that were saved for compatible classes
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
+
+ return init_dict, unused_kwargs, hidden_config_dict
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ config_dict["_class_name"] = self.__class__.__name__
+ config_dict["_diffusers_version"] = __version__
+
+ def to_json_saveable(value):
+ if isinstance(value, np.ndarray):
+ value = value.tolist()
+ return value
+
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
+ getattr(self, "register_to_config")(**new_kwargs)
+ init(self, *args, **init_kwargs)
+
+ return inner_init
+
+
+def flax_register_to_config(cls):
+ original_init = cls.__init__
+
+ @functools.wraps(original_init)
+ def init(self, *args, **kwargs):
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ # Ignore private kwargs in the init. Retrieve all passed attributes
+ init_kwargs = {k: v for k, v in kwargs.items()}
+
+ # Retrieve default values
+ fields = dataclasses.fields(self)
+ default_kwargs = {}
+ for field in fields:
+ # ignore flax specific attributes
+ if field.name in self._flax_internal_args:
+ continue
+ if type(field.default) == dataclasses._MISSING_TYPE:
+ default_kwargs[field.name] = None
+ else:
+ default_kwargs[field.name] = getattr(self, field.name)
+
+ # Make sure init_kwargs override default kwargs
+ new_kwargs = {**default_kwargs, **init_kwargs}
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
+ if "dtype" in new_kwargs:
+ new_kwargs.pop("dtype")
+
+ # Get positional arguments aligned with kwargs
+ for i, arg in enumerate(args):
+ name = fields[i].name
+ new_kwargs[name] = arg
+
+ getattr(self, "register_to_config")(**new_kwargs)
+ original_init(self, *args, **kwargs)
+
+ cls.__init__ = init
+ return cls
diff --git a/diffusers/dependency_versions_check.py b/diffusers/dependency_versions_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf863222a52fd60a15a95be0fbd6391acd3ba6d
--- /dev/null
+++ b/diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+ if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
+ require_version_core(deps[pkg])
+ else:
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+ require_version(deps[pkg], hint)
diff --git a/diffusers/dependency_versions_table.py b/diffusers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd6bfa1fa17aaeab7adf089e605380ff508a725
--- /dev/null
+++ b/diffusers/dependency_versions_table.py
@@ -0,0 +1,33 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update``
+deps = {
+ "Pillow": "Pillow",
+ "accelerate": "accelerate>=0.11.0",
+ "black": "black==22.8",
+ "datasets": "datasets",
+ "filelock": "filelock",
+ "flake8": "flake8>=3.8.3",
+ "flax": "flax>=0.4.1",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.10.0",
+ "importlib_metadata": "importlib_metadata",
+ "isort": "isort>=5.5.4",
+ "jax": "jax>=0.2.8,!=0.3.2",
+ "jaxlib": "jaxlib>=0.1.65",
+ "modelcards": "modelcards>=0.1.4",
+ "numpy": "numpy",
+ "parameterized": "parameterized",
+ "pytest": "pytest",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "safetensors": "safetensors",
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+ "scipy": "scipy",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "tensorboard": "tensorboard",
+ "torch": "torch>=1.4",
+ "torchvision": "torchvision",
+ "transformers": "transformers>=4.21.0",
+}
diff --git a/diffusers/dynamic_modules_utils.py b/diffusers/dynamic_modules_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f3bed2ecf9794b1bf9dab265af32f98dbb7afc
--- /dev/null
+++ b/diffusers/dynamic_modules_utils.py
@@ -0,0 +1,428 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import inspect
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
+
+from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+COMMUNITY_PIPELINES_URL = (
+ "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+ relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from .xxx import yyy`
+ relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+ imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from xxx import yyy`
+ imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+
+ if class_name is None:
+ return find_pipeline_class(module)
+ return getattr(module, class_name)
+
+
+def find_pipeline_class(loaded_module):
+ """
+ Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
+ inheriting from `DiffusionPipeline`.
+ """
+ from .pipeline_utils import DiffusionPipeline
+
+ cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
+
+ pipeline_class = None
+ for cls_name, cls in cls_members.items():
+ if (
+ cls_name != DiffusionPipeline.__name__
+ and issubclass(cls, DiffusionPipeline)
+ and cls.__module__.split(".")[0] != "diffusers"
+ ):
+ if pipeline_class is not None:
+ raise ValueError(
+ f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:"
+ f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in"
+ f" {loaded_module}."
+ )
+ pipeline_class = cls
+
+ return pipeline_class
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
+ Transformers module.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+
+ if os.path.isfile(module_file_or_url):
+ resolved_module_file = module_file_or_url
+ submodule = "local"
+ elif pretrained_model_name_or_path.count("/") == 0:
+ # community pipeline on GitHub
+ github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
+ try:
+ resolved_module_file = cached_download(
+ github_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=False,
+ )
+ submodule = "git"
+ module_file = pretrained_model_name_or_path + ".py"
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+ submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ if submodule == "local" or submodule == "git":
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ else:
+ # Get the commit hash
+ # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
+ if isinstance(use_auth_token, str):
+ token = use_auth_token
+ elif use_auth_token is True:
+ token = HfFolder.get_token()
+ else:
+ token = None
+
+ commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
+
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
+ # benefit of versioning.
+ submodule_path = submodule_path / commit_hash
+ full_submodule = full_submodule + os.path.sep + commit_hash
+ create_dynamic_module(full_submodule)
+
+ if not (submodule_path / module_file).exists():
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ # Make sure we also have every file with relative
+ for module_needed in modules_needed:
+ if not (submodule_path / module_needed).exists():
+ get_cached_module_file(
+ pretrained_model_name_or_path,
+ f"{module_needed}.py",
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: Optional[str] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
+ or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
diff --git a/diffusers/experimental/README.md b/diffusers/experimental/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..81a9de81c73728ea41eb6e8617a5429c3c9645ff
--- /dev/null
+++ b/diffusers/experimental/README.md
@@ -0,0 +1,5 @@
+# 🧨 Diffusers Experimental
+
+We are adding experimental code to support novel applications and usages of the Diffusers library.
+Currently, the following experiments are supported:
+* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
\ No newline at end of file
diff --git a/diffusers/experimental/__init__.py b/diffusers/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc8155403016dfd8ad7fb78d246f9da9098ac50
--- /dev/null
+++ b/diffusers/experimental/__init__.py
@@ -0,0 +1 @@
+from .rl import ValueGuidedRLPipeline
diff --git a/diffusers/experimental/rl/__init__.py b/diffusers/experimental/rl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c
--- /dev/null
+++ b/diffusers/experimental/rl/__init__.py
@@ -0,0 +1 @@
+from .value_guided_sampling import ValueGuidedRLPipeline
diff --git a/diffusers/experimental/rl/value_guided_sampling.py b/diffusers/experimental/rl/value_guided_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd935f54d608f45c8ae69eda5a571f1bf65084b
--- /dev/null
+++ b/diffusers/experimental/rl/value_guided_sampling.py
@@ -0,0 +1,130 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+import tqdm
+
+from ...models.unet_1d import UNet1DModel
+from ...pipeline_utils import DiffusionPipeline
+from ...utils.dummy_pt_objects import DDPMScheduler
+
+
+class ValueGuidedRLPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ value_function: UNet1DModel,
+ unet: UNet1DModel,
+ scheduler: DDPMScheduler,
+ env,
+ ):
+ super().__init__()
+ self.value_function = value_function
+ self.unet = unet
+ self.scheduler = scheduler
+ self.env = env
+ self.data = env.get_dataset()
+ self.means = dict()
+ for key in self.data.keys():
+ try:
+ self.means[key] = self.data[key].mean()
+ except:
+ pass
+ self.stds = dict()
+ for key in self.data.keys():
+ try:
+ self.stds[key] = self.data[key].std()
+ except:
+ pass
+ self.state_dim = env.observation_space.shape[0]
+ self.action_dim = env.action_space.shape[0]
+
+ def normalize(self, x_in, key):
+ return (x_in - self.means[key]) / self.stds[key]
+
+ def de_normalize(self, x_in, key):
+ return x_in * self.stds[key] + self.means[key]
+
+ def to_torch(self, x_in):
+ if type(x_in) is dict:
+ return {k: self.to_torch(v) for k, v in x_in.items()}
+ elif torch.is_tensor(x_in):
+ return x_in.to(self.unet.device)
+ return torch.tensor(x_in, device=self.unet.device)
+
+ def reset_x0(self, x_in, cond, act_dim):
+ for key, val in cond.items():
+ x_in[:, key, act_dim:] = val.clone()
+ return x_in
+
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
+ batch_size = x.shape[0]
+ y = None
+ for i in tqdm.tqdm(self.scheduler.timesteps):
+ # create batch of timesteps to pass into model
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+ for _ in range(n_guide_steps):
+ with torch.enable_grad():
+ x.requires_grad_()
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+ grad = torch.autograd.grad([y.sum()], [x])[0]
+
+ posterior_variance = self.scheduler._get_variance(i)
+ model_std = torch.exp(0.5 * posterior_variance)
+ grad = model_std * grad
+ grad[timesteps < 2] = 0
+ x = x.detach()
+ x = x + scale * grad
+ x = self.reset_x0(x, conditions, self.action_dim)
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+ # TODO: set prediction_type when instantiating the model
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+
+ # apply conditions to the trajectory
+ x = self.reset_x0(x, conditions, self.action_dim)
+ x = self.to_torch(x)
+ return x, y
+
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
+ # normalize the observations and create batch dimension
+ obs = self.normalize(obs, "observations")
+ obs = obs[None].repeat(batch_size, axis=0)
+
+ conditions = {0: self.to_torch(obs)}
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
+ x1 = torch.randn(shape, device=self.unet.device)
+ x = self.reset_x0(x1, conditions, self.action_dim)
+ x = self.to_torch(x)
+
+ # run the diffusion process
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+
+ # sort output trajectories by value
+ sorted_idx = y.argsort(0, descending=True).squeeze()
+ sorted_values = x[sorted_idx]
+ actions = sorted_values[:, :, : self.action_dim]
+ actions = actions.detach().cpu().numpy()
+ denorm_actions = self.de_normalize(actions, key="actions")
+
+ # select the action with the highest value
+ if y is not None:
+ selected_index = 0
+ else:
+ # if we didn't run value guiding, select a random action
+ selected_index = np.random.randint(0, batch_size)
+ denorm_actions = denorm_actions[selected_index, 0]
+ return denorm_actions
diff --git a/diffusers/hub_utils.py b/diffusers/hub_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1772d8f70bf8b5cf73c18bafe107d55f81c0f27
--- /dev/null
+++ b/diffusers/hub_utils.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+from uuid import uuid4
+
+from huggingface_hub import HfFolder, whoami
+
+from . import __version__
+from .utils import ENV_VARS_TRUE_VALUES, logging
+from .utils.import_utils import (
+ _flax_version,
+ _jax_version,
+ _onnxruntime_version,
+ _torch_version,
+ is_flax_available,
+ is_modelcards_available,
+ is_onnx_available,
+ is_torch_available,
+)
+
+
+if is_modelcards_available():
+ from modelcards import CardData, ModelCard
+
+
+logger = logging.get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+SESSION_ID = uuid4().hex
+DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
+
+
+def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+ """
+ Formats a user-agent string with basic info about a request.
+ """
+ ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
+ if DISABLE_TELEMETRY:
+ return ua + "; telemetry/off"
+ if is_torch_available():
+ ua += f"; torch/{_torch_version}"
+ if is_flax_available():
+ ua += f"; jax/{_jax_version}"
+ ua += f"; flax/{_flax_version}"
+ if is_onnx_available():
+ ua += f"; onnxruntime/{_onnxruntime_version}"
+ # CI will set this value to True
+ if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
+ ua += "; is_ci/true"
+ if isinstance(user_agent, dict):
+ ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
+ elif isinstance(user_agent, str):
+ ua += "; " + user_agent
+ return ua
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def create_model_card(args, model_name):
+ if not is_modelcards_available:
+ raise ValueError(
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+ " install the package with `pip install modelcards`."
+ )
+
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ repo_name = get_full_repo_name(model_name, token=hub_token)
+
+ model_card = ModelCard.from_template(
+ card_data=CardData( # Card metadata object that will be converted to YAML block
+ language="en",
+ license="apache-2.0",
+ library_name="diffusers",
+ tags=[],
+ datasets=args.dataset_name,
+ metrics=[],
+ ),
+ template_path=MODEL_CARD_TEMPLATE_PATH,
+ model_name=model_name,
+ repo_name=repo_name,
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+ learning_rate=args.learning_rate,
+ train_batch_size=args.train_batch_size,
+ eval_batch_size=args.eval_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps
+ if hasattr(args, "gradient_accumulation_steps")
+ else None,
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+ mixed_precision=args.mixed_precision,
+ )
+
+ card_path = os.path.join(args.output_dir, "README.md")
+ model_card.save(card_path)
diff --git a/diffusers/modeling_flax_pytorch_utils.py b/diffusers/modeling_flax_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c7a5de2ad6e9735294286955c24b92c957aba5b
--- /dev/null
+++ b/diffusers/modeling_flax_pytorch_utils.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch - Flax general utilities."""
+import re
+
+import jax.numpy as jnp
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def rename_key(key):
+ regex = r"\w+[.]\d+"
+ pats = re.findall(regex, key)
+ for pat in pats:
+ key = key.replace(pat, "_".join(pat.split(".")))
+ return key
+
+
+#####################
+# PyTorch => Flax #
+#####################
+
+# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
+# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
+
+ # conv norm or layer norm
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ if (
+ any("norm" in str_ for str_ in pt_tuple_key)
+ and (pt_tuple_key[-1] == "bias")
+ and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
+ ):
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ return renamed_pt_tuple_key, pt_tensor
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # embedding
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # conv layer
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+ return renamed_pt_tuple_key, pt_tensor
+
+ # linear layer
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+ if pt_tuple_key[-1] == "weight":
+ pt_tensor = pt_tensor.T
+ return renamed_pt_tuple_key, pt_tensor
+
+ # old PyTorch layer norm weight
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
+ if pt_tuple_key[-1] == "gamma":
+ return renamed_pt_tuple_key, pt_tensor
+
+ # old PyTorch layer norm bias
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
+ if pt_tuple_key[-1] == "beta":
+ return renamed_pt_tuple_key, pt_tensor
+
+ return pt_tuple_key, pt_tensor
+
+
+def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
+ # Step 1: Convert pytorch tensor to numpy
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+ # Step 2: Since the model is stateless, get random Flax params
+ random_flax_params = flax_model.init_weights(PRNGKey(init_key))
+
+ random_flax_state_dict = flatten_dict(random_flax_params)
+ flax_state_dict = {}
+
+ # Need to change some parameters name to match Flax names
+ for pt_key, pt_tensor in pt_state_dict.items():
+ renamed_pt_key = rename_key(pt_key)
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
+
+ # Correctly rename weight parameters
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
+
+ if flax_key in random_flax_state_dict:
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+ raise ValueError(
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+ )
+
+ # also add unexpected weight so that warning is thrown
+ flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+
+ return unflatten_dict(flax_state_dict)
diff --git a/diffusers/modeling_flax_utils.py b/diffusers/modeling_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..857fdd1b0b33556126693940f5d2dddb061ce916
--- /dev/null
+++ b/diffusers/modeling_flax_utils.py
@@ -0,0 +1,526 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pickle import UnpicklingError
+from typing import Any, Dict, Union
+
+import jax
+import jax.numpy as jnp
+import msgpack.exceptions
+from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.serialization import from_bytes, to_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__, is_torch_available
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ FLAX_WEIGHTS_NAME,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ WEIGHTS_NAME,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class FlaxModelMixin:
+ r"""
+ Base class for all flax models.
+
+ [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
+ downloading and saving models.
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _flax_internal_args = ["name", "parent", "dtype"]
+
+ @classmethod
+ def _from_config(cls, config, **kwargs):
+ """
+ All context managers that the model should be initialized under go here.
+ """
+ return cls(config, **kwargs)
+
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
+ """
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
+ """
+
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
+ def conditional_cast(param):
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
+ param = param.astype(dtype)
+ return param
+
+ if mask is None:
+ return jax.tree_map(conditional_cast, params)
+
+ flat_params = flatten_dict(params)
+ flat_mask, _ = jax.tree_flatten(mask)
+
+ for masked, key in zip(flat_mask, flat_params.keys()):
+ if masked:
+ param = flat_params[key]
+ flat_params[key] = conditional_cast(param)
+
+ return unflatten_dict(flat_params)
+
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
+ the `params` in place.
+
+ This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # load model
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
+ >>> params = model.to_bf16(params)
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> flat_params = traverse_util.flatten_dict(params)
+ >>> mask = {
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> params = model.to_bf16(params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
+
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # Download model and configuration from huggingface.co
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
+ >>> # we'll first cast to fp16 and back to fp32
+ >>> params = model.to_f16(params)
+ >>> # now cast back to fp32
+ >>> params = model.to_fp32(params)
+ ```"""
+ return self._cast_floating_to(params, jnp.float32, mask)
+
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
+ r"""
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
+ `params` in place.
+
+ This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
+
+ Arguments:
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ mask (`Union[Dict, FrozenDict]`):
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+ you want to cast, and should be `False` for those you want to skip
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # load model
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # By default, the model params will be in fp32, to cast these to float16
+ >>> params = model.to_fp16(params)
+ >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
+ >>> # then pass the mask as follows
+ >>> from flax import traverse_util
+
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> flat_params = traverse_util.flatten_dict(params)
+ >>> mask = {
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+ ... for path in flat_params
+ ... }
+ >>> mask = traverse_util.unflatten_dict(mask)
+ >>> params = model.to_fp16(params, mask)
+ ```"""
+ return self._cast_floating_to(params, jnp.float16, mask)
+
+ def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
+ raise NotImplementedError(f"init_weights method has to be implemented for {self}")
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ dtype: jnp.dtype = jnp.float32,
+ *model_args,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a pretrained flax model from a pre-trained model configuration.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids are namespaced under a user or organization name, like
+ `runwayml/stable-diffusion-v1-5`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
+ e.g., `./my_model_directory/`.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
+ [`~ModelMixin.to_bf16`].
+ model_args (sequence of positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ from_pt (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a PyTorch checkpoint save file.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
+ already been done)
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+ initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
+ a configuration attribute will be used to override said attribute with the supplied `kwargs`
+ value. Remaining keys that do not correspond to any configuration attribute will be passed to the
+ underlying model's `__init__` function.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import FlaxUNet2DConditionModel
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
+ ```"""
+ config = kwargs.pop("config", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ from_pt = kwargs.pop("from_pt", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "flax",
+ }
+
+ # Load config if we don't provide a configuration
+ config_path = config if config is not None else pretrained_model_name_or_path
+ model, model_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ # model args
+ dtype=dtype,
+ **kwargs,
+ )
+
+ # Load model
+ pretrained_path_with_subfolder = (
+ pretrained_model_name_or_path
+ if subfolder is None
+ else os.path.join(pretrained_model_name_or_path, subfolder)
+ )
+ if os.path.isdir(pretrained_path_with_subfolder):
+ if from_pt:
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
+ )
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
+ # Load from a Flax checkpoint
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
+ # Check if pytorch weights exist instead
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
+ raise EnvironmentError(
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
+ " using `from_pt=True`."
+ )
+ else:
+ raise EnvironmentError(
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+ f"{pretrained_path_with_subfolder}."
+ )
+ else:
+ try:
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
+ f"{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
+ " internet connection or see how to run the library in offline mode at"
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+ )
+
+ if from_pt:
+ if is_torch_available():
+ from .modeling_utils import load_state_dict
+ else:
+ raise EnvironmentError(
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
+ "Please, install PyTorch or use native Flax weights."
+ )
+
+ # Step 1: Get the pytorch file
+ pytorch_model_file = load_state_dict(model_file)
+
+ # Step 2: Convert the weights
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+ else:
+ try:
+ with open(model_file, "rb") as state_f:
+ state = from_bytes(cls, state_f.read())
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+ try:
+ with open(model_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+ # make sure all arrays are stored as jnp.ndarray
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
+ # https://github.com/google/flax/issues/1261
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
+
+ # flatten dicts
+ state = flatten_dict(state)
+
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
+
+ missing_keys = required_params - set(state.keys())
+ unexpected_keys = set(state.keys()) - required_params
+
+ if missing_keys:
+ logger.warning(
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+ "Make sure to call model.init_weights to initialize the missing weights."
+ )
+ cls._missing_keys = missing_keys
+
+ for key in state.keys():
+ if key in shape_state and state[key].shape != shape_state[key].shape:
+ raise ValueError(
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
+ )
+
+ # remove unexpected keys to not be saved again
+ for unexpected_key in unexpected_keys:
+ del state[unexpected_key]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ else:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
+ )
+
+ return model, unflatten_dict(state)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ params: Union[Dict, FrozenDict],
+ is_main_process: bool = True,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~FlaxModelMixin.from_pretrained`]` class method
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ params (`Union[Dict, FrozenDict]`):
+ A `PyTree` of model parameters.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # save model
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
+ with open(output_model_file, "wb") as f:
+ model_bytes = to_bytes(params)
+ f.write(model_bytes)
+
+ logger.info(f"Model weights saved in {output_model_file}")
diff --git a/diffusers/modeling_utils.py b/diffusers/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e270f75e056e9130ae9a7df590a1e7547efceee8
--- /dev/null
+++ b/diffusers/modeling_utils.py
@@ -0,0 +1,764 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from functools import partial
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ is_accelerate_available,
+ is_safetensors_available,
+ is_torch_version,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_version(">=", "1.9.0"):
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
+
+
+if is_accelerate_available():
+ import accelerate
+ from accelerate.utils import set_module_tensor_to_device
+ from accelerate.utils.versions import is_torch_version
+
+if is_safetensors_available():
+ import safetensors
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
+ return torch.load(checkpoint_file, map_location="cpu")
+ else:
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~modeling_utils.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+ _supports_gradient_checkpointing = False
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def is_gradient_checkpointing(self) -> bool:
+ """
+ Whether gradient checkpointing is activated for this model or not.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+ def enable_gradient_checkpointing(self):
+ """
+ Activates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if not self._supports_gradient_checkpointing:
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+ def disable_gradient_checkpointing(self):
+ """
+ Deactivates gradient checkpointing for the current model.
+
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+ activations".
+ """
+ if self._supports_gradient_checkpointing:
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ """
+ if safe_serialization and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ save_function = safetensors.torch.save_file if safe_serialization else torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weights_name))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if device_map is not None and not is_accelerate_available():
+ raise NotImplementedError(
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
+ )
+
+ # Check if we can handle device_map and dispatching the weights
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+ # Load model
+
+ model_file = None
+ if is_safetensors_available():
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=SAFETENSORS_WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ except:
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+
+ if low_cpu_mem_usage:
+ # Instantiate model with empty weights
+ with accelerate.init_empty_weights():
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ # if device_map is Non,e load the state dict on move the params from meta device to the cpu
+ if device_map is None:
+ param_device = "cpu"
+ state_dict = load_state_dict(model_file)
+ # move the parms from meta device to cpu
+ for param_name, param in state_dict.items():
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
+ else: # else let accelerate handle loading and dispatching.
+ # Load weights and dispatch according to the device_map
+ # by deafult the device_map is None and the weights are loaded on the CPU
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
+
+ loading_info = {
+ "missing_keys": [],
+ "unexpected_keys": [],
+ "mismatched_keys": [],
+ "error_msgs": [],
+ }
+ else:
+ config, unused_kwargs = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ **kwargs,
+ )
+ model = cls.from_config(config, **unused_kwargs)
+
+ state_dict = load_state_dict(model_file)
+ dtype = set(v.dtype for v in state_dict.values())
+
+ if len(dtype) > 1 and torch.float32 not in dtype:
+ raise ValueError(
+ f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
+ f" make sure that {model_file} weights have only one dtype."
+ )
+ elif len(dtype) > 1 and torch.float32 in dtype:
+ dtype = torch.float32
+ else:
+ dtype = dtype.pop()
+
+ # move model to correct dtype
+ model = model.to(dtype)
+
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = [k for k in state_dict.keys()]
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embeddings parameters
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def _get_model_file(
+ pretrained_model_name_or_path,
+ *,
+ weights_name,
+ subfolder,
+ cache_dir,
+ force_download,
+ proxies,
+ resume_download,
+ local_files_only,
+ use_auth_token,
+ user_agent,
+ revision,
+):
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+ return model_file
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+ return model_file
+ else:
+ raise EnvironmentError(
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=weights_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+ return model_file
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {weights_name} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {weights_name}"
+ )
diff --git a/diffusers/models/README.md b/diffusers/models/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..80fe0bc381406457665d632816891fe364efd71f
--- /dev/null
+++ b/diffusers/models/README.md
@@ -0,0 +1,3 @@
+# Models
+
+For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models).
\ No newline at end of file
diff --git a/diffusers/models/__init__.py b/diffusers/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b101d1691483ea051352216589cdf7cebfed81e
--- /dev/null
+++ b/diffusers/models/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+ from .attention import Transformer2DModel
+ from .unet_1d import UNet1DModel
+ from .unet_2d import UNet2DModel
+ from .unet_2d_condition import UNet2DConditionModel
+ from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+ from .vae_flax import FlaxAutoencoderKL
diff --git a/diffusers/models/attention.py b/diffusers/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad0af18c1c935e96de4e2928a2977f2f96e7ea6
--- /dev/null
+++ b/diffusers/models/attention.py
@@ -0,0 +1,833 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..models.embeddings import ImagePositionalEmbeddings
+from ..utils import BaseOutput
+from ..utils.import_utils import is_xformers_available
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
+ for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
+ embeddings) inputs.
+
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
+ transformer action. Finally, reshape to image.
+
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
+ classes of unnoised image.
+
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to but not more than steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = in_channels is not None
+ self.is_input_vectorized = num_vector_embeds is not None
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized:
+ raise ValueError(
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if self.is_input_continuous:
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+
+ def _set_attention_slice(self, slice_size):
+ for block in self.transformer_blocks:
+ block._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+ When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ # 1. Input
+ if self.is_input_continuous:
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (`int`): The number of channels in the input and output.
+ num_head_channels (`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ norm_num_groups: int = 32,
+ rescale_output_factor: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels, 1)
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
+
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
+
+ attention_scores = torch.baddbmm(
+ torch.empty(
+ query_proj.shape[0],
+ query_proj.shape[1],
+ key_proj.shape[1],
+ dtype=query_proj.dtype,
+ device=query_proj.device,
+ ),
+ query_proj,
+ key_proj.transpose(-1, -2),
+ beta=0,
+ alpha=scale,
+ )
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+ hidden_states = torch.bmm(attention_probs, value_proj)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ ) # is self-attn if context is none
+
+ # layer norms
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # if xformers is installed try to use memory_efficient_attention by default
+ if is_xformers_available():
+ try:
+ self.set_use_memory_efficient_attention_xformers(True)
+ except Exception as e:
+ warnings.warn(
+ "Could not enable memory efficient attention. Make sure xformers is installed"
+ f" correctly and a GPU is available: {e}"
+ )
+
+ def _set_attention_slice(self, slice_size):
+ self.attn1._slice_size = slice_size
+ self.attn2._slice_size = slice_size
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, context=None, timestep=None):
+ # 1. Self-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+ else:
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+ # 2. Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+
+ # 3. Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the context. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, hidden_states, context=None, mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ query = self.to_q(hidden_states)
+ context = context if context is not None else hidden_states
+ key = self.to_k(context)
+ value = self.to_v(context)
+
+ dim = query.shape[-1]
+
+ query = self.reshape_heads_to_batch_dim(query)
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value):
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+ attention_probs = attention_scores.softmax(dim=-1)
+ # compute attention output
+
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query[start_idx:end_idx],
+ key[start_idx:end_idx].transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+ attn_slice = attn_slice.softmax(dim=-1)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value):
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "geglu":
+ geglu = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ geglu = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(geglu)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ """
+ The approximate form of Gaussian Error Linear Unit (GELU)
+
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Norm layer modified to incorporate timestep embeddings.
+ """
+
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x, timestep):
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class DualTransformer2DModel(nn.Module):
+ """
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to but not more than steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ ):
+ super().__init__()
+ self.transformers = nn.ModuleList(
+ [
+ Transformer2DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ num_layers=num_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ sample_size=sample_size,
+ num_vector_embeds=num_vector_embeds,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ )
+ for _ in range(2)
+ ]
+ )
+
+ # Variables that can be set by a pipeline:
+
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
+ self.mix_ratio = 0.5
+
+ # The shape of `encoder_hidden_states` is expected to be
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+ self.condition_lengths = [77, 257]
+
+ # Which transformer to use to encode which condition.
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+ self.transformer_index_for_condition = [1, 0]
+
+ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ input_states = hidden_states
+
+ encoded_states = []
+ tokens_start = 0
+ for i in range(2):
+ # for each of the two transformers, pass the corresponding condition tokens
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+ transformer_index = self.transformer_index_for_condition[i]
+ encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
+ 0
+ ]
+ encoded_states.append(encoded_state - input_states)
+ tokens_start += self.condition_lengths[i]
+
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+ output_states = output_states + input_states
+
+ if not return_dict:
+ return (output_states,)
+
+ return Transformer2DModelOutput(sample=output_states)
diff --git a/diffusers/models/attention_flax.py b/diffusers/models/attention_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..71106e05452cc7525cfbb81f2ac52926887313ec
--- /dev/null
+++ b/diffusers/models/attention_flax.py
@@ -0,0 +1,298 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+class FlaxAttentionBlock(nn.Module):
+ r"""
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
+
+ Parameters:
+ query_dim (:obj:`int`):
+ Input hidden states dimension
+ heads (:obj:`int`, *optional*, defaults to 8):
+ Number of heads
+ dim_head (:obj:`int`, *optional*, defaults to 64):
+ Hidden states dimension inside each head
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+
+ """
+ query_dim: int
+ heads: int = 8
+ dim_head: int = 64
+ dropout: float = 0.0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ inner_dim = self.dim_head * self.heads
+ self.scale = self.dim_head**-0.5
+
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
+
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def __call__(self, hidden_states, context=None, deterministic=True):
+ context = hidden_states if context is None else context
+
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(context)
+ value_proj = self.value(context)
+
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
+
+ # compute attentions
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+ attention_scores = attention_scores * self.scale
+ attention_probs = nn.softmax(attention_scores, axis=2)
+
+ # attend to values
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ hidden_states = self.proj_attn(hidden_states)
+ return hidden_states
+
+
+class FlaxBasicTransformerBlock(nn.Module):
+ r"""
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
+ https://arxiv.org/abs/1706.03762
+
+
+ Parameters:
+ dim (:obj:`int`):
+ Inner hidden states dimension
+ n_heads (:obj:`int`):
+ Number of heads
+ d_head (:obj:`int`):
+ Hidden states dimension inside each head
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ only_cross_attention (`bool`, defaults to `False`):
+ Whether to only apply cross attention.
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ dim: int
+ n_heads: int
+ d_head: int
+ dropout: float = 0.0
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ # self attention (or cross_attention if only_cross_attention is True)
+ self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ # cross attention
+ self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+
+ def __call__(self, hidden_states, context, deterministic=True):
+ # self attention
+ residual = hidden_states
+ if self.only_cross_attention:
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+ else:
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ # cross attention
+ residual = hidden_states
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ # feed forward
+ residual = hidden_states
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class FlaxTransformer2DModel(nn.Module):
+ r"""
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
+ https://arxiv.org/pdf/1506.02025.pdf
+
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input number of channels
+ n_heads (:obj:`int`):
+ Number of heads
+ d_head (:obj:`int`):
+ Hidden states dimension inside each head
+ depth (:obj:`int`, *optional*, defaults to 1):
+ Number of transformers block
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ use_linear_projection (`bool`, defaults to `False`): tbd
+ only_cross_attention (`bool`, defaults to `False`): tbd
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ n_heads: int
+ d_head: int
+ depth: int = 1
+ dropout: float = 0.0
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+
+ inner_dim = self.n_heads * self.d_head
+ if self.use_linear_projection:
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+ else:
+ self.proj_in = nn.Conv(
+ inner_dim,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ self.transformer_blocks = [
+ FlaxBasicTransformerBlock(
+ inner_dim,
+ self.n_heads,
+ self.d_head,
+ dropout=self.dropout,
+ only_cross_attention=self.only_cross_attention,
+ dtype=self.dtype,
+ )
+ for _ in range(self.depth)
+ ]
+
+ if self.use_linear_projection:
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+ else:
+ self.proj_out = nn.Conv(
+ inner_dim,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states, context, deterministic=True):
+ batch, height, width, channels = hidden_states.shape
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+ if self.use_linear_projection:
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
+ hidden_states = self.proj_in(hidden_states)
+ else:
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
+
+ for transformer_block in self.transformer_blocks:
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
+
+ if self.use_linear_projection:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
+ else:
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+class FlaxGluFeedForward(nn.Module):
+ r"""
+ Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
+ https://arxiv.org/abs/2002.05202
+
+ Parameters:
+ dim (:obj:`int`):
+ Inner hidden states dimension
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ dim: int
+ dropout: float = 0.0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ # The second linear layer needs to be called
+ # net_2 for now to match the index of the Sequential layer
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.net_0(hidden_states)
+ hidden_states = self.net_2(hidden_states)
+ return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+ r"""
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+ https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim (:obj:`int`):
+ Input hidden states dimension
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ dim: int
+ dropout: float = 0.0
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ inner_dim = self.dim * 4
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.proj(hidden_states)
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+ return hidden_linear * nn.gelu(hidden_gelu)
diff --git a/diffusers/models/embeddings.py b/diffusers/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..0221d891f171fa18f7d5648c7f6a3bbc0b1c4c90
--- /dev/null
+++ b/diffusers/models/embeddings.py
@@ -0,0 +1,200 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ elif act_fn == "mish":
+ self.act = nn.Mish()
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+ # 1 x H x W x D -> 1 x L xD
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
diff --git a/diffusers/models/embeddings_flax.py b/diffusers/models/embeddings_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb1dbab3cb4767018ecbffa6cf8d2f4c0527fa5
--- /dev/null
+++ b/diffusers/models/embeddings_flax.py
@@ -0,0 +1,94 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+def get_sinusoidal_embeddings(
+ timesteps: jnp.ndarray,
+ embedding_dim: int,
+ freq_shift: float = 1,
+ min_timescale: float = 1,
+ max_timescale: float = 1.0e4,
+ flip_sin_to_cos: bool = False,
+ scale: float = 1.0,
+) -> jnp.ndarray:
+ """Returns the positional encoding (same as Tensor2Tensor).
+ Args:
+ timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ embedding_dim: The number of output channels.
+ min_timescale: The smallest time unit (should probably be 0.0).
+ max_timescale: The largest time unit.
+ Returns:
+ a Tensor of timing signals [N, num_channels]
+ """
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
+ num_timescales = float(embedding_dim // 2)
+ log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
+
+ # scale embeddings
+ scaled_time = scale * emb
+
+ if flip_sin_to_cos:
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
+ else:
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
+ return signal
+
+
+class FlaxTimestepEmbedding(nn.Module):
+ r"""
+ Time step Embedding Module. Learns embeddings for input time steps.
+
+ Args:
+ time_embed_dim (`int`, *optional*, defaults to `32`):
+ Time step embedding dimension
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ time_embed_dim: int = 32
+ dtype: jnp.dtype = jnp.float32
+
+ @nn.compact
+ def __call__(self, temb):
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
+ temb = nn.silu(temb)
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
+ return temb
+
+
+class FlaxTimesteps(nn.Module):
+ r"""
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+
+ Args:
+ dim (`int`, *optional*, defaults to `32`):
+ Time step embedding dimension
+ """
+ dim: int = 32
+ flip_sin_to_cos: bool = False
+ freq_shift: float = 1
+
+ @nn.compact
+ def __call__(self, timesteps):
+ return get_sinusoidal_embeddings(
+ timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
+ )
diff --git a/diffusers/models/resnet.py b/diffusers/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..52d056ae96fb39e00558b199c93d8cf25996bafb
--- /dev/null
+++ b/diffusers/models/resnet.py
@@ -0,0 +1,665 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Upsample1D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ use_conv_transpose:
+ out_channels:
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ if self.use_conv:
+ x = self.conv(x)
+
+ return x
+
+
+class Downsample1D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ out_channels:
+ padding:
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.conv(x)
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ use_conv_transpose:
+ out_channels:
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+ # https://github.com/pytorch/pytorch/issues/86679
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ out_channels:
+ padding:
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class FirUpsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+ datatype as `hidden_states`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ output_shape = (
+ (hidden_states.shape[2] - 1) * factor + convH,
+ (hidden_states.shape[3] - 1) * factor + convW,
+ )
+ output_padding = (
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ num_groups = hidden_states.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ inverse_conv = F.conv_transpose2d(
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+ )
+
+ output = upfirdn2d_native(
+ inverse_conv,
+ torch.tensor(kernel, device=inverse_conv.device),
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+ )
+ else:
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+ )
+
+ return output
+
+ def forward(self, hidden_states):
+ if self.use_conv:
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight:
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
+ same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
+ stride_value = [factor, factor]
+ upfirdn_input = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+ else:
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ torch.tensor(kernel, device=hidden_states.device),
+ down=factor,
+ pad=((pad_value + 1) // 2, pad_value // 2),
+ )
+
+ return output
+
+ def forward(self, hidden_states):
+ if self.use_conv:
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+ return hidden_states
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ input_tensor = input_tensor.contiguous()
+ hidden_states = hidden_states.contiguous()
+ input_tensor = self.upsample(input_tensor)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ input_tensor = self.downsample(input_tensor)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
+
+
+# unet_rl.py
+def rearrange_dims(tensor):
+ if len(tensor.shape) == 2:
+ return tensor[:, :, None]
+ if len(tensor.shape) == 3:
+ return tensor[:, :, None, :]
+ elif len(tensor.shape) == 4:
+ return tensor[:, :, 0, :]
+ else:
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
+
+
+class Conv1dBlock(nn.Module):
+ """
+ Conv1d --> GroupNorm --> Mish
+ """
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
+ self.mish = nn.Mish()
+
+ def forward(self, x):
+ x = self.conv1d(x)
+ x = rearrange_dims(x)
+ x = self.group_norm(x)
+ x = rearrange_dims(x)
+ x = self.mish(x)
+ return x
+
+
+# unet_rl.py
+class ResidualTemporalBlock1D(nn.Module):
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+ super().__init__()
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
+
+ self.time_emb_act = nn.Mish()
+ self.time_emb = nn.Linear(embed_dim, out_channels)
+
+ self.residual_conv = (
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+ )
+
+ def forward(self, x, t):
+ """
+ Args:
+ x : [ batch_size x inp_channels x horizon ]
+ t : [ batch_size x embed_dim ]
+
+ returns:
+ out : [ batch_size x out_channels x horizon ]
+ """
+ t = self.time_emb_act(t)
+ t = self.time_emb(t)
+ out = self.conv_in(x) + rearrange_dims(t)
+ out = self.conv_out(out)
+ return out + self.residual_conv(x)
+
+
+def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+ r"""Upsample2D a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+ a: multiple of the upsampling factor.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states,
+ kernel.to(device=hidden_states.device),
+ up=factor,
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+ )
+ return output
+
+
+def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
+ r"""Downsample2D a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+
+ Args:
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = torch.tensor(kernel, dtype=torch.float32)
+ if kernel.ndim == 1:
+ kernel = torch.outer(kernel, kernel)
+ kernel /= torch.sum(kernel)
+
+ kernel = kernel * gain
+ pad_value = kernel.shape[0] - factor
+ output = upfirdn2d_native(
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+ )
+ return output
+
+
+def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = tensor.shape
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = tensor.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(tensor.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/diffusers/models/resnet_flax.py b/diffusers/models/resnet_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..632780378ee0e8fa49404ecae470146250270ce5
--- /dev/null
+++ b/diffusers/models/resnet_flax.py
@@ -0,0 +1,124 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+class FlaxUpsample2D(nn.Module):
+ out_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ batch, height, width, channels = hidden_states.shape
+ hidden_states = jax.image.resize(
+ hidden_states,
+ shape=(batch, height * 2, width * 2, channels),
+ method="nearest",
+ )
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxDownsample2D(nn.Module):
+ out_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding=((1, 1), (1, 1)), # padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
+ # hidden_states = jnp.pad(hidden_states, pad_width=pad)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxResnetBlock2D(nn.Module):
+ in_channels: int
+ out_channels: int = None
+ dropout_prob: float = 0.0
+ use_nin_shortcut: bool = None
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.conv1 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
+
+ self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.dropout = nn.Dropout(self.dropout_prob)
+ self.conv2 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
+
+ self.conv_shortcut = None
+ if use_nin_shortcut:
+ self.conv_shortcut = nn.Conv(
+ out_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states, temb, deterministic=True):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ temb = self.time_emb_proj(nn.swish(temb))
+ temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.dropout(hidden_states, deterministic)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
diff --git a/diffusers/models/unet_1d.py b/diffusers/models/unet_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d1d707f55a026458defd2bc0ec089ecc10653a
--- /dev/null
+++ b/diffusers/models/unet_1d.py
@@ -0,0 +1,245 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
+
+
+@dataclass
+class UNet1DOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
+ Hidden states output. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet1DModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
+ in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
+ time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
+ obj:`False`): Whether to flip sin to cos for fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(32, 32, 64)`): Tuple of block output channels.
+ mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
+ out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
+ act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
+ norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
+ layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
+ downsample_each_block (`int`, *optional*, defaults to False:
+ experimental feature for using a UNet without upsampling.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 65536,
+ sample_rate: Optional[int] = None,
+ in_channels: int = 2,
+ out_channels: int = 2,
+ extra_in_channels: int = 0,
+ time_embedding_type: str = "fourier",
+ flip_sin_to_cos: bool = True,
+ use_timestep_embedding: bool = False,
+ freq_shift: float = 0.0,
+ down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
+ mid_block_type: Tuple[str] = "UNetMidBlock1D",
+ out_block_type: str = None,
+ block_out_channels: Tuple[int] = (32, 32, 64),
+ act_fn: str = None,
+ norm_num_groups: int = 8,
+ layers_per_block: int = 1,
+ downsample_each_block: bool = False,
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(
+ embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
+ )
+ timestep_input_dim = block_out_channels[0]
+
+ if use_timestep_embedding:
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_mlp = TimestepEmbedding(
+ in_channels=timestep_input_dim,
+ time_embed_dim=time_embed_dim,
+ act_fn=act_fn,
+ out_dim=block_out_channels[0],
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+ self.out_block = None
+
+ # down
+ output_channel = in_channels
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+
+ if i == 0:
+ input_channel += extra_in_channels
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_downsample=not is_final_block or downsample_each_block,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ in_channels=block_out_channels[-1],
+ mid_channels=block_out_channels[-1],
+ out_channels=block_out_channels[-1],
+ embed_dim=block_out_channels[0],
+ num_layers=layers_per_block,
+ add_downsample=downsample_each_block,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ if out_block_type is None:
+ final_upsample_channels = out_channels
+ else:
+ final_upsample_channels = block_out_channels[0]
+
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = (
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
+ )
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_upsample=not is_final_block,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.out_block = get_out_block(
+ out_block_type=out_block_type,
+ num_groups_out=num_groups_out,
+ embed_dim=block_out_channels[0],
+ out_channels=out_channels,
+ act_fn=act_fn,
+ fc_dim=block_out_channels[-1] // 4,
+ )
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet1DOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ timestep_embed = self.time_proj(timesteps)
+ if self.config.use_timestep_embedding:
+ timestep_embed = self.time_mlp(timestep_embed)
+ else:
+ timestep_embed = timestep_embed[..., None]
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+
+ # 2. down
+ down_block_res_samples = ()
+ for downsample_block in self.down_blocks:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
+ down_block_res_samples += res_samples
+
+ # 3. mid
+ if self.mid_block:
+ sample = self.mid_block(sample, timestep_embed)
+
+ # 4. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-1:]
+ down_block_res_samples = down_block_res_samples[:-1]
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
+
+ # 5. post-process
+ if self.out_block:
+ sample = self.out_block(sample, timestep_embed)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet1DOutput(sample=sample)
diff --git a/diffusers/models/unet_1d_blocks.py b/diffusers/models/unet_1d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc758ebbb044644e921c7e66089e052981a82e1e
--- /dev/null
+++ b/diffusers/models/unet_1d_blocks.py
@@ -0,0 +1,668 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
+
+
+class DownResnetBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ conv_shortcut=False,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.add_downsample = add_downsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.downsample is not None:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states, output_states
+
+
+class UpResnetBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.time_embedding_norm = time_embedding_norm
+ self.add_upsample = add_upsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
+
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
+ if res_hidden_states_tuple is not None:
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ hidden_states = self.upsample(hidden_states)
+
+ return hidden_states
+
+
+class ValueFunctionMidBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, embed_dim):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.embed_dim = embed_dim
+
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
+
+ def forward(self, x, temb=None):
+ x = self.res1(x, temb)
+ x = self.down1(x)
+ x = self.res2(x, temb)
+ x = self.down2(x)
+ return x
+
+
+class MidResTemporalBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ embed_dim,
+ num_layers: int = 1,
+ add_downsample: bool = False,
+ add_upsample: bool = False,
+ non_linearity=None,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.add_downsample = add_downsample
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Downsample1D(out_channels, use_conv=True)
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True)
+
+ if self.upsample and self.downsample:
+ raise ValueError("Block cannot downsample and upsample")
+
+ def forward(self, hidden_states, temb):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsample:
+ hidden_states = self.upsample(hidden_states)
+ if self.downsample:
+ self.downsample = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class OutConv1DBlock(nn.Module):
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
+ super().__init__()
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
+ if act_fn == "silu":
+ self.final_conv1d_act = nn.SiLU()
+ if act_fn == "mish":
+ self.final_conv1d_act = nn.Mish()
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.final_conv1d_1(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_gn(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_act(hidden_states)
+ hidden_states = self.final_conv1d_2(hidden_states)
+ return hidden_states
+
+
+class OutValueFunctionBlock(nn.Module):
+ def __init__(self, fc_dim, embed_dim):
+ super().__init__()
+ self.final_block = nn.ModuleList(
+ [
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
+ nn.Mish(),
+ nn.Linear(fc_dim // 2, 1),
+ ]
+ )
+
+ def forward(self, hidden_states, temb):
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
+ for layer in self.final_block:
+ hidden_states = layer(hidden_states)
+
+ return hidden_states
+
+
+_kernels = {
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+ "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+ "lanczos3": [
+ 0.003689131001010537,
+ 0.015056144446134567,
+ -0.03399861603975296,
+ -0.066637322306633,
+ 0.13550527393817902,
+ 0.44638532400131226,
+ 0.44638532400131226,
+ 0.13550527393817902,
+ -0.066637322306633,
+ -0.03399861603975296,
+ 0.015056144446134567,
+ 0.003689131001010537,
+ ],
+}
+
+
+class Downsample1d(nn.Module):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel])
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer("kernel", kernel_1d)
+
+ def forward(self, hidden_states):
+ hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ return F.conv1d(hidden_states, weight, stride=2)
+
+
+class Upsample1d(nn.Module):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer("kernel", kernel_1d)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+class SelfAttention1d(nn.Module):
+ def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
+ super().__init__()
+ self.channels = in_channels
+ self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
+ self.num_heads = n_head
+
+ self.query = nn.Linear(self.channels, self.channels)
+ self.key = nn.Linear(self.channels, self.channels)
+ self.value = nn.Linear(self.channels, self.channels)
+
+ self.proj_attn = nn.Linear(self.channels, self.channels, 1)
+
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel_dim, seq = hidden_states.shape
+
+ hidden_states = self.group_norm(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores, dim=-1)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.dropout(hidden_states)
+
+ output = hidden_states + residual
+
+ return output
+
+
+class ResConvBlock(nn.Module):
+ def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
+ super().__init__()
+ self.is_last = is_last
+ self.has_conv_skip = in_channels != out_channels
+
+ if self.has_conv_skip:
+ self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+
+ self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
+ self.group_norm_1 = nn.GroupNorm(1, mid_channels)
+ self.gelu_1 = nn.GELU()
+ self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
+
+ if not self.is_last:
+ self.group_norm_2 = nn.GroupNorm(1, out_channels)
+ self.gelu_2 = nn.GELU()
+
+ def forward(self, hidden_states):
+ residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
+
+ hidden_states = self.conv_1(hidden_states)
+ hidden_states = self.group_norm_1(hidden_states)
+ hidden_states = self.gelu_1(hidden_states)
+ hidden_states = self.conv_2(hidden_states)
+
+ if not self.is_last:
+ hidden_states = self.group_norm_2(hidden_states)
+ hidden_states = self.gelu_2(hidden_states)
+
+ output = hidden_states + residual
+ return output
+
+
+class UNetMidBlock1D(nn.Module):
+ def __init__(self, mid_channels, in_channels, out_channels=None):
+ super().__init__()
+
+ out_channels = in_channels if out_channels is None else out_channels
+
+ # there is always at least one resnet
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+ self.up = Upsample1d(kernel="cubic")
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+ for attn, resnet in zip(self.attentions, self.resnets):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownBlock1D(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class DownBlock1D(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ self.down = Downsample1d("cubic")
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.down(hidden_states)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class DownBlock1DNoSkip(nn.Module):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = torch.cat([hidden_states, temb], dim=1)
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states, (hidden_states,)
+
+
+class AttnUpBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = out_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+ attentions = [
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(mid_channels, mid_channels // 32),
+ SelfAttention1d(out_channels, out_channels // 32),
+ ]
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.up = Upsample1d(kernel="cubic")
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states)
+ hidden_states = attn(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock1D(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = in_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+ self.up = Upsample1d(kernel="cubic")
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ hidden_states = self.up(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock1DNoSkip(nn.Module):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ mid_channels = in_channels if mid_channels is None else mid_channels
+
+ resnets = [
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
+ ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
+ ]
+
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
+ if down_block_type == "DownResnetBlock1D":
+ return DownResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "DownBlock1D":
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "AttnDownBlock1D":
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "DownBlock1DNoSkip":
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
+ if up_block_type == "UpResnetBlock1D":
+ return UpResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "UpBlock1D":
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "AttnUpBlock1D":
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "UpBlock1DNoSkip":
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
+ if mid_block_type == "MidResTemporalBlock1D":
+ return MidResTemporalBlock1D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ embed_dim=embed_dim,
+ add_downsample=add_downsample,
+ )
+ elif mid_block_type == "ValueFunctionMidBlock1D":
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
+ elif mid_block_type == "UNetMidBlock1D":
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+ raise ValueError(f"{mid_block_type} does not exist.")
+
+
+def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
+ if out_block_type == "OutConv1DBlock":
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
+ elif out_block_type == "ValueFunction":
+ return OutValueFunctionBlock(fc_dim, embed_dim)
+ return None
diff --git a/diffusers/models/unet_2d.py b/diffusers/models/unet_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b337f482cc561ad968f5e05e6f7a027af394e2d
--- /dev/null
+++ b/diffusers/models/unet_2d.py
@@ -0,0 +1,264 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states output. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
+ obj:`True`): Whether to flip sin to cos for fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
+ types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(224, 448, 672, 896)`): Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ act_fn: str = "silu",
+ attention_head_dim: int = 8,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
diff --git a/diffusers/models/unet_2d_blocks.py b/diffusers/models/unet_2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d78804b18e7586d52288b29f0ed7d54873196d0e
--- /dev/null
+++ b/diffusers/models/unet_2d_blocks.py
@@ -0,0 +1,1650 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+from torch import nn
+
+from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(resnet_in_channels + res_skip_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
diff --git a/diffusers/models/unet_2d_blocks_flax.py b/diffusers/models/unet_2d_blocks_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..96e76cb06a59a31beebf4449786b72a7c838a298
--- /dev/null
+++ b/diffusers/models/unet_2d_blocks_flax.py
@@ -0,0 +1,365 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+from .attention_flax import FlaxTransformer2DModel
+from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
+
+
+class FlaxCrossAttnDownBlock2D(nn.Module):
+ r"""
+ Cross Attention 2D Downsizing block - original architecture from Unet transformers:
+ https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+ attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+ Number of attention heads of each spatial transformer block
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add downsampling layer before each final output
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ attn_num_head_channels: int = 1
+ add_downsample: bool = True
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ attentions = []
+
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.out_channels,
+ n_heads=self.attn_num_head_channels,
+ d_head=self.out_channels // self.attn_num_head_channels,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=self.only_cross_attention,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+ output_states += (hidden_states,)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class FlaxDownBlock2D(nn.Module):
+ r"""
+ Flax 2D downsizing block
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add downsampling layer before each final output
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ add_downsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+ self.resnets = resnets
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, temb, deterministic=True):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ output_states += (hidden_states,)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class FlaxCrossAttnUpBlock2D(nn.Module):
+ r"""
+ Cross Attention 2D Upsampling block - original architecture from Unet transformers:
+ https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+ attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+ Number of attention heads of each spatial transformer block
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add upsampling layer before each final output
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ prev_output_channel: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ attn_num_head_channels: int = 1
+ add_upsample: bool = True
+ use_linear_projection: bool = False
+ only_cross_attention: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ attentions = []
+
+ for i in range(self.num_layers):
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.out_channels,
+ n_heads=self.attn_num_head_channels,
+ d_head=self.out_channels // self.attn_num_head_channels,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=self.only_cross_attention,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
+
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUpBlock2D(nn.Module):
+ r"""
+ Flax 2D upsampling block
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ prev_output_channel (:obj:`int`):
+ Output channels from the previous block
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add downsampling layer before each final output
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ prev_output_channel: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ add_upsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+
+ for i in range(self.num_layers):
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=self.out_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
+
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUNetMidBlock2DCrossAttn(nn.Module):
+ r"""
+ Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of attention blocks layers
+ attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+ Number of attention heads of each spatial transformer block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ attn_num_head_channels: int = 1
+ use_linear_projection: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ # there is always at least one resnet
+ resnets = [
+ FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ ]
+
+ attentions = []
+
+ for _ in range(self.num_layers):
+ attn_block = FlaxTransformer2DModel(
+ in_channels=self.in_channels,
+ n_heads=self.attn_num_head_channels,
+ d_head=self.in_channels // self.attn_num_head_channels,
+ depth=1,
+ use_linear_projection=self.use_linear_projection,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout_prob=self.dropout,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
+
+ return hidden_states
diff --git a/diffusers/models/unet_2d_condition.py b/diffusers/models/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d3402d0619847c5d218dadb6dea080a992ab98
--- /dev/null
+++ b/diffusers/models/unet_2d_condition.py
@@ -0,0 +1,381 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput, logging
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UpBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the models (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ num_class_embeds: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.config.attention_head_dim
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.config.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/diffusers/models/unet_2d_condition_flax.py b/diffusers/models/unet_2d_condition_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3f1d9e146d3ad296ae2f2bfc67d87864608d8b
--- /dev/null
+++ b/diffusers/models/unet_2d_condition_flax.py
@@ -0,0 +1,321 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..modeling_flax_utils import FlaxModelMixin
+from ..utils import BaseOutput
+from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from .unet_2d_blocks_flax import (
+ FlaxCrossAttnDownBlock2D,
+ FlaxCrossAttnUpBlock2D,
+ FlaxDownBlock2D,
+ FlaxUNetMidBlock2DCrossAttn,
+ FlaxUpBlock2D,
+)
+
+
+@flax.struct.dataclass
+class FlaxUNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: jnp.ndarray
+
+
+@flax_register_to_config
+class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
+ r"""
+ FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
+ timestep and returns sample shaped output.
+
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the models (such as downloading or saving, etc.)
+
+ Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ sample_size (`int`, *optional*):
+ The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4):
+ The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4):
+ The number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
+ "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
+ "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
+ The dimension of the attention heads.
+ cross_attention_dim (`int`, *optional*, defaults to 768):
+ The dimension of the cross attention features.
+ dropout (`float`, *optional*, defaults to 0):
+ Dropout probability for down, up and bottleneck blocks.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+
+ """
+
+ sample_size: int = 32
+ in_channels: int = 4
+ out_channels: int = 4
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ )
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+ only_cross_attention: Union[bool, Tuple[bool]] = False
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+ layers_per_block: int = 2
+ attention_head_dim: Union[int, Tuple[int]] = 8
+ cross_attention_dim: int = 1280
+ dropout: float = 0.0
+ use_linear_projection: bool = False
+ dtype: jnp.dtype = jnp.float32
+ flip_sin_to_cos: bool = True
+ freq_shift: int = 0
+
+ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
+ # init input tensors
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv(
+ block_out_channels[0],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # time
+ self.time_proj = FlaxTimesteps(
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
+ )
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+
+ only_cross_attention = self.only_cross_attention
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+ attention_head_dim = self.attention_head_dim
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+
+ # down
+ down_blocks = []
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(self.down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "CrossAttnDownBlock2D":
+ down_block = FlaxCrossAttnDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ attn_num_head_channels=attention_head_dim[i],
+ add_downsample=not is_final_block,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ dtype=self.dtype,
+ )
+ else:
+ down_block = FlaxDownBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=self.dropout,
+ num_layers=self.layers_per_block,
+ add_downsample=not is_final_block,
+ dtype=self.dtype,
+ )
+
+ down_blocks.append(down_block)
+ self.down_blocks = down_blocks
+
+ # mid
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ dropout=self.dropout,
+ attn_num_head_channels=attention_head_dim[-1],
+ use_linear_projection=self.use_linear_projection,
+ dtype=self.dtype,
+ )
+
+ # up
+ up_blocks = []
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(self.up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ if up_block_type == "CrossAttnUpBlock2D":
+ up_block = FlaxCrossAttnUpBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ num_layers=self.layers_per_block + 1,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ add_upsample=not is_final_block,
+ dropout=self.dropout,
+ use_linear_projection=self.use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ dtype=self.dtype,
+ )
+ else:
+ up_block = FlaxUpBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ num_layers=self.layers_per_block + 1,
+ add_upsample=not is_final_block,
+ dropout=self.dropout,
+ dtype=self.dtype,
+ )
+
+ up_blocks.append(up_block)
+ prev_output_channel = output_channel
+ self.up_blocks = up_blocks
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+ self.conv_out = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(
+ self,
+ sample,
+ timesteps,
+ encoder_hidden_states,
+ return_dict: bool = True,
+ train: bool = False,
+ ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
+ plain tuple.
+ train (`bool`, *optional*, defaults to `False`):
+ Use deterministic functions and disable dropout when not training.
+
+ Returns:
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+ """
+ # 1. time
+ if not isinstance(timesteps, jnp.ndarray):
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+ timesteps = timesteps.astype(dtype=jnp.float32)
+ timesteps = jnp.expand_dims(timesteps, 0)
+
+ t_emb = self.time_proj(timesteps)
+ t_emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for down_block in self.down_blocks:
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+ else:
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+
+ # 5. up
+ for up_block in self.up_blocks:
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
+ down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
+ if isinstance(up_block, FlaxCrossAttnUpBlock2D):
+ sample = up_block(
+ sample,
+ temb=t_emb,
+ encoder_hidden_states=encoder_hidden_states,
+ res_hidden_states_tuple=res_samples,
+ deterministic=not train,
+ )
+ else:
+ sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = nn.silu(sample)
+ sample = self.conv_out(sample)
+ sample = jnp.transpose(sample, (0, 3, 1, 2))
+
+ if not return_dict:
+ return (sample,)
+
+ return FlaxUNet2DConditionOutput(sample=sample)
diff --git a/diffusers/models/vae.py b/diffusers/models/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..e29f4e8afa2ff1cc957672b9f2d595c30a2db32e
--- /dev/null
+++ b/diffusers/models/vae.py
@@ -0,0 +1,643 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Encoded output sample of the model. Output of the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
+ multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(
+ self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.vq_embed_dim = vq_embed_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.vq_embed_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ device = self.parameters.device
+ sample_device = "cpu" if device.type == "mps" else device
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = sample.to(device=device, dtype=self.parameters.dtype)
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+ Kavukcuoglu.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ norm_num_groups: int = 32,
+ vq_embed_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=False,
+ )
+
+ vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+ self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
+ self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+ self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ )
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+ self.use_slicing = False
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/diffusers/models/vae_flax.py b/diffusers/models/vae_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ecda9a6e9a0eafe8c9da2abb4a9dc04948a1289
--- /dev/null
+++ b/diffusers/models/vae_flax.py
@@ -0,0 +1,858 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
+
+import math
+from functools import partial
+from typing import Tuple
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..modeling_flax_utils import FlaxModelMixin
+from ..utils import BaseOutput
+
+
+@flax.struct.dataclass
+class FlaxDecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ sample: jnp.ndarray
+
+
+@flax.struct.dataclass
+class FlaxAutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`FlaxDiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
+ `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "FlaxDiagonalGaussianDistribution"
+
+
+class FlaxUpsample2D(nn.Module):
+ """
+ Flax implementation of 2D Upsample layer
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.in_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ batch, height, width, channels = hidden_states.shape
+ hidden_states = jax.image.resize(
+ hidden_states,
+ shape=(batch, height * 2, width * 2, channels),
+ method="nearest",
+ )
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxDownsample2D(nn.Module):
+ """
+ Flax implementation of 2D Downsample layer
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.conv = nn.Conv(
+ self.in_channels,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states):
+ pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
+ hidden_states = jnp.pad(hidden_states, pad_width=pad)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FlaxResnetBlock2D(nn.Module):
+ """
+ Flax implementation of 2D Resnet Block.
+
+ Args:
+ in_channels (`int`):
+ Input channels
+ out_channels (`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm.
+ use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
+ Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+
+ in_channels: int
+ out_channels: int = None
+ dropout: float = 0.0
+ groups: int = 32
+ use_nin_shortcut: bool = None
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
+ self.conv1 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
+ self.dropout_layer = nn.Dropout(self.dropout)
+ self.conv2 = nn.Conv(
+ out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
+
+ self.conv_shortcut = None
+ if use_nin_shortcut:
+ self.conv_shortcut = nn.Conv(
+ out_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def __call__(self, hidden_states, deterministic=True):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = nn.swish(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class FlaxAttentionBlock(nn.Module):
+ r"""
+ Flax Convolutional based multi-head attention block for diffusion-based VAE.
+
+ Parameters:
+ channels (:obj:`int`):
+ Input channels
+ num_head_channels (:obj:`int`, *optional*, defaults to `None`):
+ Number of attention heads
+ num_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for group norm
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+
+ """
+ channels: int
+ num_head_channels: int = None
+ num_groups: int = 32
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
+
+ dense = partial(nn.Dense, self.channels, dtype=self.dtype)
+
+ self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
+ self.query, self.key, self.value = dense(), dense(), dense()
+ self.proj_attn = dense()
+
+ def transpose_for_scores(self, projection):
+ new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
+ new_projection = projection.reshape(new_projection_shape)
+ # (B, T, H, D) -> (B, H, T, D)
+ new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
+ return new_projection
+
+ def __call__(self, hidden_states):
+ residual = hidden_states
+ batch, height, width, channels = hidden_states.shape
+
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.reshape((batch, height * width, channels))
+
+ query = self.query(hidden_states)
+ key = self.key(hidden_states)
+ value = self.value(hidden_states)
+
+ # transpose
+ query = self.transpose_for_scores(query)
+ key = self.transpose_for_scores(key)
+ value = self.transpose_for_scores(value)
+
+ # compute attentions
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+ attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
+ attn_weights = nn.softmax(attn_weights, axis=-1)
+
+ # attend to values
+ hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
+
+ hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
+ new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
+ hidden_states = hidden_states.reshape(new_hidden_states_shape)
+
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.reshape((batch, height, width, channels))
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+class FlaxDownEncoderBlock2D(nn.Module):
+ r"""
+ Flax Resnet blocks-based Encoder block for diffusion-based VAE.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add downsample layer
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ add_downsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout=self.dropout,
+ groups=self.resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+ self.resnets = resnets
+
+ if self.add_downsample:
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ if self.add_downsample:
+ hidden_states = self.downsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUpDecoderBlock2D(nn.Module):
+ r"""
+ Flax Resnet blocks-based Decoder block for diffusion-based VAE.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ out_channels (:obj:`int`):
+ Output channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet block group norm
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+ Whether to add upsample layer
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ out_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ add_upsample: bool = True
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnets = []
+ for i in range(self.num_layers):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+ res_block = FlaxResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ dropout=self.dropout,
+ groups=self.resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+
+ if self.add_upsample:
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
+
+ def __call__(self, hidden_states, deterministic=True):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ if self.add_upsample:
+ hidden_states = self.upsamplers_0(hidden_states)
+
+ return hidden_states
+
+
+class FlaxUNetMidBlock2D(nn.Module):
+ r"""
+ Flax Unet Mid-Block module.
+
+ Parameters:
+ in_channels (:obj:`int`):
+ Input channels
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
+ Dropout rate
+ num_layers (:obj:`int`, *optional*, defaults to 1):
+ Number of Resnet layer block
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
+ The number of groups to use for the Resnet and Attention block group norm
+ attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
+ Number of attention heads for each attention block
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int
+ dropout: float = 0.0
+ num_layers: int = 1
+ resnet_groups: int = 32
+ attn_num_head_channels: int = 1
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout=self.dropout,
+ groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ ]
+
+ attentions = []
+
+ for _ in range(self.num_layers):
+ attn_block = FlaxAttentionBlock(
+ channels=self.in_channels,
+ num_head_channels=self.attn_num_head_channels,
+ num_groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ attentions.append(attn_block)
+
+ res_block = FlaxResnetBlock2D(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ dropout=self.dropout,
+ groups=resnet_groups,
+ dtype=self.dtype,
+ )
+ resnets.append(res_block)
+
+ self.resnets = resnets
+ self.attentions = attentions
+
+ def __call__(self, hidden_states, deterministic=True):
+ hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
+
+ return hidden_states
+
+
+class FlaxEncoder(nn.Module):
+ r"""
+ Flax Implementation of VAE Encoder.
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (:obj:`int`, *optional*, defaults to 3):
+ Input channels
+ out_channels (:obj:`int`, *optional*, defaults to 3):
+ Output channels
+ down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
+ DownEncoder block type
+ block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
+ Tuple containing the number of output channels for each block
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
+ Number of Resnet layer for each block
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
+ norm num group
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
+ Activation function
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
+ Whether to double the last output channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ Parameters `dtype`
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
+ block_out_channels: Tuple[int] = (64,)
+ layers_per_block: int = 2
+ norm_num_groups: int = 32
+ act_fn: str = "silu"
+ double_z: bool = False
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+ # in
+ self.conv_in = nn.Conv(
+ block_out_channels[0],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # downsampling
+ down_blocks = []
+ output_channel = block_out_channels[0]
+ for i, _ in enumerate(self.down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = FlaxDownEncoderBlock2D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=self.layers_per_block,
+ resnet_groups=self.norm_num_groups,
+ add_downsample=not is_final_block,
+ dtype=self.dtype,
+ )
+ down_blocks.append(down_block)
+ self.down_blocks = down_blocks
+
+ # middle
+ self.mid_block = FlaxUNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ attn_num_head_channels=None,
+ dtype=self.dtype,
+ )
+
+ # end
+ conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
+ self.conv_out = nn.Conv(
+ conv_out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, sample, deterministic: bool = True):
+ # in
+ sample = self.conv_in(sample)
+
+ # downsampling
+ for block in self.down_blocks:
+ sample = block(sample, deterministic=deterministic)
+
+ # middle
+ sample = self.mid_block(sample, deterministic=deterministic)
+
+ # end
+ sample = self.conv_norm_out(sample)
+ sample = nn.swish(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class FlaxDecoder(nn.Module):
+ r"""
+ Flax Implementation of VAE Decoder.
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (:obj:`int`, *optional*, defaults to 3):
+ Input channels
+ out_channels (:obj:`int`, *optional*, defaults to 3):
+ Output channels
+ up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
+ UpDecoder block type
+ block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
+ Tuple containing the number of output channels for each block
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
+ Number of Resnet layer for each block
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
+ norm num group
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
+ Activation function
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
+ Whether to double the last output channels
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ parameters `dtype`
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
+ block_out_channels: int = (64,)
+ layers_per_block: int = 2
+ norm_num_groups: int = 32
+ act_fn: str = "silu"
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ block_out_channels = self.block_out_channels
+
+ # z to block_in
+ self.conv_in = nn.Conv(
+ block_out_channels[-1],
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ # middle
+ self.mid_block = FlaxUNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_groups=self.norm_num_groups,
+ attn_num_head_channels=None,
+ dtype=self.dtype,
+ )
+
+ # upsampling
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ up_blocks = []
+ for i, _ in enumerate(self.up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = FlaxUpDecoderBlock2D(
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ num_layers=self.layers_per_block + 1,
+ resnet_groups=self.norm_num_groups,
+ add_upsample=not is_final_block,
+ dtype=self.dtype,
+ )
+ up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ self.up_blocks = up_blocks
+
+ # end
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
+ self.conv_out = nn.Conv(
+ self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding=((1, 1), (1, 1)),
+ dtype=self.dtype,
+ )
+
+ def __call__(self, sample, deterministic: bool = True):
+ # z to block_in
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample, deterministic=deterministic)
+
+ # upsampling
+ for block in self.up_blocks:
+ sample = block(sample, deterministic=deterministic)
+
+ sample = self.conv_norm_out(sample)
+ sample = nn.swish(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class FlaxDiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ # Last axis to account for channels-last
+ self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
+ self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = jnp.exp(0.5 * self.logvar)
+ self.var = jnp.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = jnp.zeros_like(self.mean)
+
+ def sample(self, key):
+ return self.mean + self.std * jax.random.normal(key, self.mean.shape)
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return jnp.array([0.0])
+
+ if other is None:
+ return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
+
+ return 0.5 * jnp.sum(
+ jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ axis=[1, 2, 3],
+ )
+
+ def nll(self, sample, axis=[1, 2, 3]):
+ if self.deterministic:
+ return jnp.array([0.0])
+
+ logtwopi = jnp.log(2.0 * jnp.pi)
+ return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
+
+ def mode(self):
+ return self.mean
+
+
+@flax_register_to_config
+class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
+ r"""
+ Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational
+ Bayes by Diederik P. Kingma and Max Welling.
+
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+ general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ in_channels (:obj:`int`, *optional*, defaults to 3):
+ Input channels
+ out_channels (:obj:`int`, *optional*, defaults to 3):
+ Output channels
+ down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
+ DownEncoder block type
+ up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
+ UpDecoder block type
+ block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
+ Tuple containing the number of output channels for each block
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
+ Number of Resnet layer for each block
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
+ Activation function
+ latent_channels (:obj:`int`, *optional*, defaults to `4`):
+ Latent space channels
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
+ Norm num group
+ sample_size (:obj:`int`, *optional*, defaults to `32`):
+ Sample input size
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+ parameters `dtype`
+ """
+ in_channels: int = 3
+ out_channels: int = 3
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
+ block_out_channels: Tuple[int] = (64,)
+ layers_per_block: int = 1
+ act_fn: str = "silu"
+ latent_channels: int = 4
+ norm_num_groups: int = 32
+ sample_size: int = 32
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.encoder = FlaxEncoder(
+ in_channels=self.config.in_channels,
+ out_channels=self.config.latent_channels,
+ down_block_types=self.config.down_block_types,
+ block_out_channels=self.config.block_out_channels,
+ layers_per_block=self.config.layers_per_block,
+ act_fn=self.config.act_fn,
+ norm_num_groups=self.config.norm_num_groups,
+ double_z=True,
+ dtype=self.dtype,
+ )
+ self.decoder = FlaxDecoder(
+ in_channels=self.config.latent_channels,
+ out_channels=self.config.out_channels,
+ up_block_types=self.config.up_block_types,
+ block_out_channels=self.config.block_out_channels,
+ layers_per_block=self.config.layers_per_block,
+ norm_num_groups=self.config.norm_num_groups,
+ act_fn=self.config.act_fn,
+ dtype=self.dtype,
+ )
+ self.quant_conv = nn.Conv(
+ 2 * self.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+ self.post_quant_conv = nn.Conv(
+ self.config.latent_channels,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding="VALID",
+ dtype=self.dtype,
+ )
+
+ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
+ # init input tensors
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+
+ params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
+ rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
+
+ return self.init(rngs, sample)["params"]
+
+ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
+
+ hidden_states = self.encoder(sample, deterministic=deterministic)
+ moments = self.quant_conv(hidden_states)
+ posterior = FlaxDiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
+ if latents.shape[-1] != self.config.latent_channels:
+ latents = jnp.transpose(latents, (0, 2, 3, 1))
+
+ hidden_states = self.post_quant_conv(latents)
+ hidden_states = self.decoder(hidden_states, deterministic=deterministic)
+
+ hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return FlaxDecoderOutput(sample=hidden_states)
+
+ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
+ posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
+ if sample_posterior:
+ rng = self.make_rng("gaussian")
+ hidden_states = posterior.latent_dist.sample(rng)
+ else:
+ hidden_states = posterior.latent_dist.mode()
+
+ sample = self.decode(hidden_states, return_dict=return_dict).sample
+
+ if not return_dict:
+ return (sample,)
+
+ return FlaxDecoderOutput(sample=sample)
diff --git a/diffusers/onnx_utils.py b/diffusers/onnx_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c533ed741f213c28df8d917702e8400a199443
--- /dev/null
+++ b/diffusers/onnx_utils.py
@@ -0,0 +1,213 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
+
+
+if is_onnx_available():
+ import onnxruntime as ort
+
+
+logger = logging.get_logger(__name__)
+
+ORT_TO_NP_TYPE = {
+ "tensor(bool)": np.bool_,
+ "tensor(int8)": np.int8,
+ "tensor(uint8)": np.uint8,
+ "tensor(int16)": np.int16,
+ "tensor(uint16)": np.uint16,
+ "tensor(int32)": np.int32,
+ "tensor(uint32)": np.uint32,
+ "tensor(int64)": np.int64,
+ "tensor(uint64)": np.uint64,
+ "tensor(float16)": np.float16,
+ "tensor(float)": np.float32,
+ "tensor(double)": np.float64,
+}
+
+
+class OnnxRuntimeModel:
+ def __init__(self, model=None, **kwargs):
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
+ self.model = model
+ self.model_save_dir = kwargs.get("model_save_dir", None)
+ self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
+
+ def __call__(self, **kwargs):
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
+ return self.model.run(None, inputs)
+
+ @staticmethod
+ def load_model(path: Union[str, Path], provider=None, sess_options=None):
+ """
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+ Arguments:
+ path (`str` or `Path`):
+ Directory from which to load
+ provider(`str`, *optional*):
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+ """
+ if provider is None:
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+ provider = "CPUExecutionProvider"
+
+ return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
+ latest_model_name.
+
+ Arguments:
+ save_directory (`str` or `Path`):
+ Directory where to save the model file.
+ file_name(`str`, *optional*):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
+ model with a different name.
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
+ dst_path = Path(save_directory).joinpath(model_file_name)
+ try:
+ shutil.copyfile(src_path, dst_path)
+ except shutil.SameFileError:
+ pass
+
+ # copy external weights (for models >2GB)
+ src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+ if src_path.exists():
+ dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+ try:
+ shutil.copyfile(src_path, dst_path)
+ except shutil.SameFileError:
+ pass
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ **kwargs,
+ ):
+ """
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
+ method.:
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # saving model weights/files
+ self._save_pretrained(save_directory, **kwargs)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ use_auth_token: Optional[Union[bool, str, None]] = None,
+ revision: Optional[Union[str, None]] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ file_name: Optional[str] = None,
+ provider: Optional[str] = None,
+ sess_options: Optional["ort.SessionOptions"] = None,
+ **kwargs,
+ ):
+ """
+ Load a model from a directory or the HF Hub.
+
+ Arguments:
+ model_id (`str` or `Path`):
+ Directory from which to load
+ use_auth_token (`str` or `bool`):
+ Is needed to load models from a private or gated repository
+ revision (`str`):
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
+ cache_dir (`Union[str, Path]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ file_name(`str`):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
+ different model files from the same repository or directory.
+ provider(`str`):
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
+ kwargs (`Dict`, *optional*):
+ kwargs will be passed to the model during initialization
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+ # load model from local directory
+ if os.path.isdir(model_id):
+ model = OnnxRuntimeModel.load_model(
+ os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
+ )
+ kwargs["model_save_dir"] = Path(model_id)
+ # load model from hub
+ else:
+ # download model
+ model_cache_path = hf_hub_download(
+ repo_id=model_id,
+ filename=model_file_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
+ kwargs["latest_model_name"] = Path(model_cache_path).name
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
+ return cls(model=model, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ force_download: bool = True,
+ use_auth_token: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ **model_kwargs,
+ ):
+ revision = None
+ if len(str(model_id).split("@")) == 2:
+ model_id, revision = model_id.split("@")
+
+ return cls._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ use_auth_token=use_auth_token,
+ **model_kwargs,
+ )
diff --git a/diffusers/optimization.py b/diffusers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b836b4a69bffb61c15967ef9b1736201721f1b
--- /dev/null
+++ b/diffusers/optimization.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SchedulerType(Enum):
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int``, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
diff --git a/diffusers/pipeline_flax_utils.py b/diffusers/pipeline_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8fd304776d785fc11df5f93f5c2443fb78bceef
--- /dev/null
+++ b/diffusers/pipeline_flax_utils.py
@@ -0,0 +1,506 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import flax
+import PIL
+from flax.core.frozen_dict import FrozenDict
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .hub_utils import http_user_agent
+from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
+from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
+
+
+if is_transformers_available():
+ from transformers import FlaxPreTrainedModel
+
+INDEX_FILE = "diffusion_flax_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
+ "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
+ "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ "ProcessorMixin": ["save_pretrained", "from_pretrained"],
+ "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+def import_flax_or_no_model(module, class_name):
+ try:
+ # 1. First make sure that if a Flax object is present, import this one
+ class_obj = getattr(module, "Flax" + class_name)
+ except AttributeError:
+ # 2. If this doesn't work, it's not a model and we don't append "Flax"
+ class_obj = getattr(module, class_name)
+ except AttributeError:
+ raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
+
+ return class_obj
+
+
+@flax.struct.dataclass
+class FlaxImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class FlaxDiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all models.
+
+ [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+ pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+ pipelines to:
+
+ - enabling/disabling the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ if module is None:
+ register_dict = {name: (None, None)}
+ else:
+ # retrieve library
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
+ # TODO: handle inference_state
+ """
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+ method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ if sub_model is None:
+ # edge case for saving a pipeline with safety_checker=None
+ continue
+
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class, None)
+ if class_candidate is not None and issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
+
+ if expects_params:
+ save_method(
+ os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
+ )
+ else:
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ dtype (`str` or `jnp.dtype`, *optional*):
+ Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information. specify the folder name here.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
+ specific pipeline class. The overwritten components are then directly passed to the pipelines
+ `__init__` method. See example below for more information.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```py
+ >>> from diffusers import FlaxDiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> # Requires to be logged in to Hugging Face hub,
+ >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
+ >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5",
+ ... revision="bf16",
+ ... dtype=jnp.bfloat16,
+ ... )
+
+ >>> # Download pipeline, but use a different scheduler
+ >>> from diffusers import FlaxDPMSolverMultistepScheduler
+
+ >>> model_id = "runwayml/stable-diffusion-v1-5"
+ >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ ... model_id,
+ ... subfolder="scheduler",
+ ... )
+
+ >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
+ ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
+ ... )
+ >>> dpm_params["scheduler"] = dpmpp_state
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_pt = kwargs.pop("from_pt", False)
+ dtype = kwargs.pop("dtype", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ config_dict = cls.load_config(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ # make sure we only download sub-folders and `diffusers` filenames
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
+
+ # make sure we don't download PyTorch weights, unless when using from_pt
+ ignore_patterns = "*.bin" if not from_pt else []
+
+ if cls != FlaxDiffusionPipeline:
+ requested_pipeline_class = cls.__name__
+ else:
+ requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
+ requested_pipeline_class = (
+ requested_pipeline_class
+ if requested_pipeline_class.startswith("Flax")
+ else "Flax" + requested_pipeline_class
+ )
+
+ user_agent = {"pipeline_class": requested_pipeline_class}
+ user_agent = http_user_agent(user_agent)
+
+ # download all allow_patterns
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.load_config(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != FlaxDiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ class_name = (
+ config_dict["_class_name"]
+ if config_dict["_class_name"].startswith("Flax")
+ else "Flax" + config_dict["_class_name"]
+ )
+ pipeline_class = getattr(diffusers_module, class_name)
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # inference_params
+ params = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ if class_name is None:
+ # edge case for when the pipeline was saved with safety_checker=None
+ init_kwargs[name] = None
+ continue
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+ sub_model_should_be_defined = True
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ elif passed_class_obj[name] is None:
+ logger.warning(
+ f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
+ f" that this might lead to problems when using {pipeline_class} and is not recommended."
+ )
+ sub_model_should_be_defined = False
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = import_flax_or_no_model(pipeline_module, class_name)
+
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = import_flax_or_no_model(library, class_name)
+
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None and sub_model_should_be_defined:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loadable_folder = os.path.join(cached_folder, name)
+ else:
+ loaded_sub_model = cached_folder
+
+ if issubclass(class_obj, FlaxModelMixin):
+ loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+ params[name] = loaded_params
+ elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+ if from_pt:
+ # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+ loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+ loaded_params = loaded_sub_model.params
+ del loaded_sub_model._params
+ else:
+ loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
+ params[name] = loaded_params
+ elif issubclass(class_obj, FlaxSchedulerMixin):
+ loaded_sub_model, scheduler_state = load_method(loadable_folder)
+ params[name] = scheduler_state
+ else:
+ loaded_sub_model = load_method(loadable_folder)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ model = pipeline_class(**init_kwargs, dtype=dtype)
+ return model, params
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ # TODO: make it compatible with jax.lax
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
diff --git a/diffusers/pipeline_utils.py b/diffusers/pipeline_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e65d55e20cd9faa5396ed116efcc28656079e972
--- /dev/null
+++ b/diffusers/pipeline_utils.py
@@ -0,0 +1,841 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+import diffusers
+import PIL
+from huggingface_hub import model_info, snapshot_download
+from packaging import version
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .dynamic_modules_utils import get_class_from_dynamic_module
+from .hub_utils import http_user_agent
+from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ ONNX_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ BaseOutput,
+ deprecate,
+ is_accelerate_available,
+ is_safetensors_available,
+ is_torch_version,
+ is_transformers_available,
+ logging,
+)
+
+
+if is_transformers_available():
+ import transformers
+ from transformers import PreTrainedModel
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
+DUMMY_MODULES_FOLDER = "diffusers.utils"
+TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_pretrained", "from_pretrained"],
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ "ProcessorMixin": ["save_pretrained", "from_pretrained"],
+ "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
+ },
+ "onnxruntime.training": {
+ "ORTModule": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class AudioPipelineOutput(BaseOutput):
+ """
+ Output class for audio pipelines.
+
+ Args:
+ audios (`np.ndarray`)
+ List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
+ denoised audio samples of the diffusion pipeline.
+ """
+
+ audios: np.ndarray
+
+
+def is_safetensors_compatible(info) -> bool:
+ filenames = set(sibling.rfilename for sibling in info.siblings)
+ pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
+ is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
+ for pt_filename in pt_filenames:
+ prefix, raw = os.path.split(pt_filename)
+ if raw == "pytorch_model.bin":
+ # transformers specific
+ sf_filename = os.path.join(prefix, "model.safetensors")
+ else:
+ sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
+ if is_safetensors_compatible and sf_filename not in filenames:
+ logger.warning(f"{sf_filename} not found")
+ is_safetensors_compatible = False
+ return is_safetensors_compatible
+
+
+class DiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all models.
+
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
+
+ - move all PyTorch modules to the device of your choice
+ - enabling/disabling the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** (`str`) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
+ passed for the pipeline to function (should be overridden by subclasses).
+ """
+ config_name = "model_index.json"
+ _optional_components = []
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve library
+ if module is None:
+ register_dict = {name: (None, None)}
+ else:
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ safe_serialization: bool = False,
+ ):
+ """
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+ method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ safe_serialization (`bool`, *optional*, defaults to `False`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ expected_modules, optional_kwargs = self._get_signature_keys(self)
+
+ def is_saveable_module(name, value):
+ if name not in expected_modules:
+ return False
+ if name in self._optional_components and value[0] is None:
+ return False
+ return True
+
+ model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class, None)
+ if class_candidate is not None and issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+
+ # Call the save method with the argument safe_serialization only if it's supported
+ save_method_signature = inspect.signature(save_method)
+ save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
+ if save_method_accept_safe:
+ save_method(
+ os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
+ )
+ else:
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ if torch_device is None:
+ return self
+
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
+ logger.warning(
+ "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+ " is not recommended to move them to `cpu` as running them will fail. Please make"
+ " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+ " support for`float16` operations on this device in PyTorch. Please, remove the"
+ " `torch_dtype=torch.float16` argument, or use another device for inference."
+ )
+ module.to(torch_device)
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ return module.device
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ custom_pipeline (`str`, *optional*):
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Can be either:
+
+ - A string, the *repo id* of a custom pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
+ like `hf-internal-testing/diffusers-dummy-pipeline`.
+
+
+
+ It is required that the model repo has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ - A string, the *file name* of a community pipeline hosted on GitHub under
+ https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
+ match exactly the file name without `.py` located under the above link, *e.g.*
+ `clip_guided_stable_diffusion`.
+
+
+
+ Community pipelines are always loaded from the current `main` branch of GitHub.
+
+
+
+ - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
+
+
+
+ It is required that the directory has a file, called `pipeline.py` that defines the custom
+ pipeline.
+
+
+
+ For more information on how to load and create custom pipelines, please have a look at [Loading and
+ Adding Custom
+ Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
+
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information. specify the folder name here.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ return_cached_folder (`bool`, *optional*, defaults to `False`):
+ If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
+ specific pipeline class. The overwritten components are then directly passed to the pipelines
+ `__init__` method. See example below for more information.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to this section
+ >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
+ >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+ >>> # Use a different scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.scheduler = scheduler
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
+ provider = kwargs.pop("provider", None)
+ sess_options = kwargs.pop("sess_options", None)
+ device_map = kwargs.pop("device_map", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+ return_cached_folder = kwargs.pop("return_cached_folder", False)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot download here to get it working from from_pretrained
+ if not os.path.isdir(pretrained_model_name_or_path):
+ config_dict = cls.load_config(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ # make sure we only download sub-folders and `diffusers` filenames
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
+
+ # make sure we don't download flax weights
+ ignore_patterns = ["*.msgpack"]
+
+ if custom_pipeline is not None:
+ allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
+
+ if cls != DiffusionPipeline:
+ requested_pipeline_class = cls.__name__
+ else:
+ requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
+ user_agent = {"pipeline_class": requested_pipeline_class}
+ if custom_pipeline is not None:
+ user_agent["custom_pipeline"] = custom_pipeline
+ user_agent = http_user_agent(user_agent)
+
+ if is_safetensors_available():
+ info = model_info(
+ pretrained_model_name_or_path,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ if is_safetensors_compatible(info):
+ ignore_patterns.append("*.bin")
+
+ # download all allow_patterns
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.load_config(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if custom_pipeline is not None:
+ if custom_pipeline.endswith(".py"):
+ path = Path(custom_pipeline)
+ # decompose into folder & file
+ file_name = path.name
+ custom_pipeline = path.parent.absolute()
+ else:
+ file_name = CUSTOM_PIPELINE_FILE_NAME
+
+ pipeline_class = get_class_from_dynamic_module(
+ custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
+ )
+ elif cls != DiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+ # To be removed in 1.0.0
+ if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+ version.parse(config_dict["_diffusers_version"]).base_version
+ ) <= version.parse("0.5.1"):
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+ pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+ deprecation_message = (
+ "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+ f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+ " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+ " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+ f" checkpoint {pretrained_model_name_or_path} to the format of"
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+ " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+ )
+ deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ # define init kwargs
+ init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+ init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+ # remove `null` components
+ def load_module(name, value):
+ if value[0] is None:
+ return False
+ if name in passed_class_obj and passed_class_obj[name] is None:
+ return False
+ return True
+
+ init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+ )
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ if low_cpu_mem_usage is False and device_map is not None:
+ raise ValueError(
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+ )
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+ if class_name.startswith("Flax"):
+ class_name = class_name[4:]
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if class_candidate is not None and issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ if load_method_name is None:
+ none_module = class_obj.__module__
+ is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
+ TRANSFORMERS_DUMMY_MODULES_FOLDER
+ )
+ if is_dummy_path and "dummy" in none_module:
+ # call class_obj for nice error message of missing requirements
+ class_obj()
+
+ raise ValueError(
+ f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
+ f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
+ )
+
+ load_method = getattr(class_obj, load_method_name)
+ loading_kwargs = {}
+
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+ loading_kwargs["sess_options"] = sess_options
+
+ is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
+ is_transformers_model = (
+ is_transformers_available()
+ and issubclass(class_obj, PreTrainedModel)
+ and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
+ )
+
+ # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
+ # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
+ # This makes sure that the weights won't be initialized which significantly speeds up loading.
+ if is_diffusers_model or is_transformers_model:
+ loading_kwargs["device_map"] = device_map
+ loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Potentially add passed objects if expected
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
+ passed_modules = list(passed_class_obj.keys())
+ optional_modules = pipeline_class._optional_components
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+ for module in missing_modules:
+ init_kwargs[module] = passed_class_obj.get(module, None)
+ elif len(missing_modules) > 0:
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+ )
+
+ # 5. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+
+ if return_cached_folder:
+ return model, cached_folder
+ return model
+
+ @staticmethod
+ def _get_signature_keys(obj):
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+ expected_modules = set(required_parameters.keys()) - set(["self"])
+ return expected_modules, optional_parameters
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ r"""
+
+ The `self.components` property can be useful to run different pipelines with the same weights and
+ configurations to not have to re-allocate memory.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import (
+ ... StableDiffusionPipeline,
+ ... StableDiffusionImg2ImgPipeline,
+ ... StableDiffusionInpaintPipeline,
+ ... )
+
+ >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+ >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
+ ```
+
+ Returns:
+ A dictionaly containing all the modules needed to initialize the pipeline.
+ """
+ expected_modules, optional_parameters = self._get_signature_keys(self)
+ components = {
+ k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
+ }
+
+ if set(components.keys()) != expected_modules:
+ raise ValueError(
+ f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
+ f" {expected_modules} to be defined, but {components} are defined."
+ )
+
+ return components
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def progress_bar(self, iterable=None, total=None):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ if iterable is not None:
+ return tqdm(iterable, **self._progress_bar_config)
+ elif total is not None:
+ return tqdm(total=total, **self._progress_bar_config)
+ else:
+ raise ValueError("Either `total` or `iterable` has to be defined.")
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
+
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+ """
+ self.set_use_memory_efficient_attention_xformers(True)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
+ for module_name in module_names:
+ module = getattr(self, module_name)
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
diff --git a/diffusers/pipelines/README.md b/diffusers/pipelines/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c3202db0270c29e4827d16233f67915a1424697e
--- /dev/null
+++ b/diffusers/pipelines/README.md
@@ -0,0 +1,173 @@
+# 🧨 Diffusers Pipelines
+
+Pipelines provide a simple way to run state-of-the-art diffusion models in inference.
+Most diffusion systems consist of multiple independently-trained models and highly adaptable scheduler
+components - all of which are needed to have a functioning end-to-end diffusion system.
+
+As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models:
+- [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392)
+- [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12)
+- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPTextModel)
+- a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py),
+- a [CLIPFeatureExtractor](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPFeatureExtractor),
+- as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py).
+All of these components are necessary to run stable diffusion in inference even though they were trained
+or created independently from each other.
+
+To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
+More specifically, we strive to provide pipelines that
+- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
+- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
+- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
+- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
+
+**Note** that pipelines do not (and should not) offer any training functionality.
+If you are looking for *official* training examples, please have a look at [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+
+
+## Pipelines Summary
+
+The following table summarizes all officially supported pipelines, their corresponding paper, and if
+available a colab notebook to directly try them out.
+
+| Pipeline | Source | Tasks | Colab
+|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:|
+| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* |
+| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* |
+| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* |
+| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* |
+| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
+| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
+| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
+| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
+
+**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
+However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below.
+
+## Pipelines API
+
+Diffusion models often consist of multiple independently-trained models or other previously existing components.
+
+
+Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one.
+During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality:
+
+- [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.*
+"./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be
+loaded into the pipelines. More specifically, for each model/component one needs to define the format `: ["", ""]`. `` is the attribute name given to the loaded instance of `` which can be found in the library or pipeline folder called `""`.
+- [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`.
+In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated
+from the local path.
+- [`to`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L118) which accepts a `string` or `torch.device` to move all models that are of type `torch.nn.Module` to the passed device. The behavior is fully analogous to [PyTorch's `to` method](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to).
+- [`__call__`] method to use the pipeline in inference. `__call__` defines inference logic of the pipeline and should ideally encompass all aspects of it, from pre-processing to forwarding tensors to the different models and schedulers, as well as post-processing. The API of the `__call__` method can strongly vary from pipeline to pipeline. *E.g.* a text-to-image pipeline, such as [`StableDiffusionPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) should accept among other things the text prompt to generate the image. A pure image generation pipeline, such as [DDPMPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/ddpm) on the other hand can be run without providing any inputs. To better understand what inputs can be adapted for
+each pipeline, one should look directly into the respective pipeline.
+
+**Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should
+not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community)
+
+## Contribution
+
+We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire
+all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
+
+- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
+- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
+use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
+logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.
+- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better.
+- **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*.
+
+## Examples
+
+### Text-to-Image generation with Stable Diffusion
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Image-to-Image text-guided generation with Stable Diffusion
+
+The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
+
+```python
+import requests
+from PIL import Image
+from io import BytesIO
+
+from diffusers import StableDiffusionImg2ImgPipeline
+
+# load the pipeline
+device = "cuda"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ revision="fp16",
+ torch_dtype=torch.float16,
+).to(device)
+
+# let's download an initial image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+images[0].save("fantasy_landscape.png")
+```
+You can also run this example on colab [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+
+### Tweak prompts reusing seeds and latents
+
+You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
+
+
+### In-painting using Stable Diffusion
+
+The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt.
+
+```python
+import PIL
+import requests
+import torch
+from io import BytesIO
+
+from diffusers import StableDiffusionInpaintPipeline
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ revision="fp16",
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+You can also run this example on colab [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
diff --git a/diffusers/pipelines/__init__.py b/diffusers/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aba302042b989f5e891633be98110f76988ae2
--- /dev/null
+++ b/diffusers/pipelines/__init__.py
@@ -0,0 +1,48 @@
+from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+ from .dance_diffusion import DanceDiffusionPipeline
+ from .ddim import DDIMPipeline
+ from .ddpm import DDPMPipeline
+ from .latent_diffusion import LDMSuperResolutionPipeline
+ from .latent_diffusion_uncond import LDMPipeline
+ from .pndm import PNDMPipeline
+ from .repaint import RePaintPipeline
+ from .score_sde_ve import ScoreSdeVePipeline
+ from .stochastic_karras_ve import KarrasVePipeline
+else:
+ from ..utils.dummy_pt_objects import * # noqa F403
+
+if is_torch_available() and is_transformers_available():
+ from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
+ from .latent_diffusion import LDMTextToImagePipeline
+ from .stable_diffusion import (
+ CycleDiffusionPipeline,
+ StableDiffusionImageVariationPipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionPipeline,
+ StableDiffusionUpscalePipeline,
+ )
+ from .stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .versatile_diffusion import (
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ )
+ from .vq_diffusion import VQDiffusionPipeline
+
+if is_transformers_available() and is_onnx_available():
+ from .stable_diffusion import (
+ OnnxStableDiffusionImg2ImgPipeline,
+ OnnxStableDiffusionInpaintPipeline,
+ OnnxStableDiffusionInpaintPipelineLegacy,
+ OnnxStableDiffusionPipeline,
+ StableDiffusionOnnxPipeline,
+ )
+
+if is_transformers_available() and is_flax_available():
+ from .stable_diffusion import FlaxStableDiffusionPipeline
diff --git a/diffusers/pipelines/alt_diffusion/__init__.py b/diffusers/pipelines/alt_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d0d9b7852c4babfe26c33874bcb1bf52271b39
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/__init__.py
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
+class AltDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Alt Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+if is_transformers_available() and is_torch_available():
+ from .modeling_roberta_series import RobertaSeriesModelWithTransformation
+ from .pipeline_alt_diffusion import AltDiffusionPipeline
+ from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
diff --git a/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e92314162d3424935082ead79da9694f2569fe1
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -0,0 +1,110 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class TransformationModelOutput(ModelOutput):
+ """
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ projection_state: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(
+ self,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ project_dim=512,
+ pooler_fn="cls",
+ learn_encoder=False,
+ use_attention_mask=True,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+ self.use_attention_mask = use_attention_mask
+
+
+class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ base_model_prefix = "roberta"
+ config_class = RobertaSeriesConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ):
+ r""" """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.base_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return TransformationModelOutput(
+ projection_state=projection_state,
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb64a34a0bd89f45dd27e4143aa0c3093d4d6e65
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -0,0 +1,579 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import deprecate, logging
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Alt Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..346f5f727bb87c66e4777894fc4f6726fe82b6f3
--- /dev/null
+++ b/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -0,0 +1,601 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Alt Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ image = image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/dance_diffusion/__init__.py b/diffusers/pipelines/dance_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad34fc52aaa61f9313cae32d7bb39acad831104
--- /dev/null
+++ b/diffusers/pipelines/dance_diffusion/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_dance_diffusion import DanceDiffusionPipeline
diff --git a/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..48d16889a030217b5d203233678a10e3eb7ae9d2
--- /dev/null
+++ b/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -0,0 +1,119 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class DanceDiffusionPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`IPNDMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 100,
+ generator: Optional[torch.Generator] = None,
+ audio_length_in_s: Optional[float] = None,
+ return_dict: bool = True,
+ ) -> Union[AudioPipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of audio samples to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
+ the expense of slower inference.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
+ The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
+ `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if audio_length_in_s is None:
+ audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
+
+ sample_size = audio_length_in_s * self.unet.sample_rate
+
+ down_scale_factor = 2 ** len(self.unet.up_blocks)
+ if sample_size < 3 * down_scale_factor:
+ raise ValueError(
+ f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
+ f" {3 * down_scale_factor / self.unet.sample_rate}."
+ )
+
+ original_sample_size = int(sample_size)
+ if sample_size % down_scale_factor != 0:
+ sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+ logger.info(
+ f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
+ f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
+ " process."
+ )
+ sample_size = int(sample_size)
+
+ dtype = next(iter(self.unet.parameters())).dtype
+ audio = torch.randn(
+ (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
+ )
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+ self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(audio, t).sample
+
+ # 2. compute previous image: x_t -> t_t-1
+ audio = self.scheduler.step(model_output, t, audio).prev_sample
+
+ audio = audio.clamp(-1, 1).float().cpu().numpy()
+
+ audio = audio[:, :, :original_sample_size]
+
+ if not return_dict:
+ return (audio,)
+
+ return AudioPipelineOutput(audios=audio)
diff --git a/diffusers/pipelines/ddim/__init__.py b/diffusers/pipelines/ddim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd31868a88ac0d9ec7118574f21a9d8a1d4069b
--- /dev/null
+++ b/diffusers/pipelines/ddim/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddim import DDIMPipeline
diff --git a/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/pipelines/ddim/pipeline_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e590dea64621ea9eada0fd5d962e58943a2775
--- /dev/null
+++ b/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -0,0 +1,126 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate
+
+
+class DDIMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ use_clipped_model_output: Optional[bool] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+ if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+ downstream to the scheduler. So use `None` for schedulers which don't support this argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+ message = (
+ f"The `generator` device is `{generator.device}` and does not match the pipeline "
+ f"device `{self.device}`, so the `generator` will be ignored. "
+ f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+ )
+ deprecate(
+ "generator.device == 'cpu'",
+ "0.11.0",
+ message,
+ )
+ generator = None
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.sample_size, int):
+ image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+ else:
+ image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = torch.randn(image_shape, generator=generator)
+ image = image.to(self.device)
+ else:
+ image = torch.randn(image_shape, generator=generator, device=self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(
+ model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+ ).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8889bdae1224e91916e0f8454bafba0ee566f3b9
--- /dev/null
+++ b/diffusers/pipelines/ddpm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddpm import DDPMPipeline
diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..114a38a5fec7a471ed60be1c38ace65f86c903dd
--- /dev/null
+++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -0,0 +1,127 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ num_inference_steps: int = 1000,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 1000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+
+ if predict_epsilon is not None:
+ new_config = dict(self.scheduler.config)
+ new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
+ self.scheduler._internal_dict = FrozenDict(new_config)
+
+ if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+ message = (
+ f"The `generator` device is `{generator.device}` and does not match the pipeline "
+ f"device `{self.device}`, so the `generator` will be ignored. "
+ f'Please use `torch.Generator(device="{self.device}")` instead.'
+ )
+ deprecate(
+ "generator.device == 'cpu'",
+ "0.11.0",
+ message,
+ )
+ generator = None
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.sample_size, int):
+ image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+ else:
+ image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ image = torch.randn(image_shape, generator=generator)
+ image = image.to(self.device)
+ else:
+ image = torch.randn(image_shape, generator=generator, device=self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5544527ff5877bb2c725c8b375cd5b03060d6a21
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/__init__.py
@@ -0,0 +1,7 @@
+# flake8: noqa
+from ...utils import is_transformers_available
+from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
+
+
+if is_transformers_available():
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e903cb836a32c85f442f30ccdea08cfc67425dd
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -0,0 +1,711 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+class LDMTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: Union[VQModel, AutoencoderKL],
+ bert: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ unet: Union[UNet2DModel, UNet2DConditionModel],
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 1.0,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at
+ the, usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get unconditional embeddings for classifier free guidance
+ if guidance_scale != 1.0:
+ uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+ uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+ text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale == 1.0:
+ # guidance_scale of 1 means no guidance
+ latents_input = latents
+ context = text_embeddings
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = torch.cat([latents] * 2)
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
+ # perform guidance
+ if guidance_scale != 1.0:
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+
+################################################################################
+# Code for the text transformer model
+################################################################################
+""" PyTorch LDMBERT model."""
+
+
+logger = logging.get_logger(__name__)
+
+LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "ldm-bert",
+ # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
+]
+
+
+LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json",
+}
+
+
+""" LDMBERT model configuration"""
+
+
+class LDMBertConfig(PretrainedConfig):
+ model_type = "ldmbert"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ max_position_embeddings=77,
+ encoder_layers=32,
+ encoder_ffn_dim=5120,
+ encoder_attention_heads=8,
+ head_dim=64,
+ encoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1280,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
+class LDMBertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+ self.inner_dim = head_dim * num_heads
+
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.out_proj = nn.Linear(self.inner_dim, embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class LDMBertEncoderLayer(nn.Module):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = LDMBertAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ head_dim=config.head_dim,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
+class LDMBertPreTrainedModel(PreTrainedModel):
+ config_class = LDMBertConfig
+ base_model_prefix = "model"
+ _supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LDMBertEncoder,)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class LDMBertEncoder(LDMBertPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`LDMBertEncoderLayer`].
+
+ Args:
+ config: LDMBertConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+ self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
+ self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_len = input_shape[1]
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
+ embed_pos = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class LDMBertModel(LDMBertPreTrainedModel):
+ _no_split_modules = []
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+ self.model = LDMBertEncoder(config)
+ self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..09bdca54accfb51cd12afa1a103d2f88a909215b
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -0,0 +1,171 @@
+import inspect
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+
+import PIL
+
+from ...models import UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+class LDMSuperResolutionPipeline(DiffusionPipeline):
+ r"""
+ A pipeline for image super-resolution using Latent
+
+ This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
+ [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: VQModel,
+ unet: UNet2DModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ ):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.Tensor, PIL.Image.Image],
+ batch_size: Optional[int] = 1,
+ num_inference_steps: Optional[int] = 100,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, torch.Tensor):
+ batch_size = image.shape[0]
+ else:
+ raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ height, width = image.shape[-2:]
+
+ # in_channels should be 6: 3 for latents, 3 for low resolution image
+ latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
+ latents_dtype = next(self.unet.parameters()).dtype
+
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
+ latents = latents.to(self.device)
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+
+ image = image.to(device=self.device, dtype=latents_dtype)
+
+ # set timesteps and move to the correct device
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(timesteps_tensor):
+ # concat latents and low resolution image in the channel dimension.
+ latents_input = torch.cat([latents, image], dim=1)
+ latents_input = self.scheduler.scale_model_input(latents_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VQVAE
+ image = self.vqvae.decode(latents).sample
+ image = torch.clamp(image, -1.0, 1.0)
+ image = image / 2 + 0.5
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0826ca7536c706f9bc1f310c157068efbca7f0b3
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_latent_diffusion_uncond import LDMPipeline
diff --git a/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..5345c4e5625ee519a411b4fd80468fc991757165
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+
+
+class LDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
+ super().__init__()
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
+ # predict the noise residual
+ noise_prediction = self.unet(latent_model_input, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VAE
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/pndm/__init__.py b/diffusers/pipelines/pndm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc46aaab9fa26e83b49c26843d854e217742664
--- /dev/null
+++ b/diffusers/pipelines/pndm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_pndm import PNDMPipeline
diff --git a/diffusers/pipelines/pndm/pipeline_pndm.py b/diffusers/pipelines/pndm/pipeline_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7062dea19cd34d533dbf7eee25fd3d0c21b4f8
--- /dev/null
+++ b/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -0,0 +1,96 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import PNDMScheduler
+
+
+class PNDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ unet: UNet2DModel
+ scheduler: PNDMScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, `optional`, defaults to 1): The number of images to generate.
+ num_inference_steps (`int`, `optional`, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator`, `optional`): A [torch
+ generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose
+ between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
+ [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # For more information on the sampling method you can take a look at Algorithm 2 of
+ # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ model_output = self.unet(image, t).sample
+
+ image = self.scheduler.step(model_output, t, image).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/repaint/__init__.py b/diffusers/pipelines/repaint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bc86d1cedf6243fb92f7ba331b5a6188133298
--- /dev/null
+++ b/diffusers/pipelines/repaint/__init__.py
@@ -0,0 +1 @@
+from .pipeline_repaint import RePaintPipeline
diff --git a/diffusers/pipelines/repaint/pipeline_repaint.py b/diffusers/pipelines/repaint/pipeline_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af88f62755983ce41f4566a3a33a0e624d5e94f
--- /dev/null
+++ b/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -0,0 +1,140 @@
+# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+import PIL
+from tqdm.auto import tqdm
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import RePaintScheduler
+
+
+def _preprocess_image(image: PIL.Image.Image):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ return image
+
+
+def _preprocess_mask(mask: PIL.Image.Image):
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+class RePaintPipeline(DiffusionPipeline):
+ unet: UNet2DModel
+ scheduler: RePaintScheduler
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ original_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ num_inference_steps: int = 250,
+ eta: float = 0.0,
+ jump_length: int = 10,
+ jump_n_sample: int = 10,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The original image to inpaint on.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The mask_image where 0.0 values define which part of the original image to inpaint (change).
+ num_inference_steps (`int`, *optional*, defaults to 1000):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`):
+ The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM
+ and 1.0 is DDPM scheduler respectively.
+ jump_length (`int`, *optional*, defaults to 10):
+ The number of steps taken forward in time before going backward in time for a single jump ("j" in
+ RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ jump_n_sample (`int`, *optional*, defaults to 10):
+ The number of times we will make forward time jump for a given chosen time sample. Take a look at
+ Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if not isinstance(original_image, torch.FloatTensor):
+ original_image = _preprocess_image(original_image)
+ original_image = original_image.to(self.device)
+ if not isinstance(mask_image, torch.FloatTensor):
+ mask_image = _preprocess_mask(mask_image)
+ mask_image = mask_image.to(self.device)
+
+ # sample gaussian noise to begin the loop
+ image = torch.randn(
+ original_image.shape,
+ generator=generator,
+ device=self.device,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
+ self.scheduler.eta = eta
+
+ t_last = self.scheduler.timesteps[0] + 1
+ for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+ if t < t_last:
+ # predict the noise residual
+ model_output = self.unet(image, t).sample
+ # compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
+
+ else:
+ # compute the reverse: x_t-1 -> x_t
+ image = self.scheduler.undo_step(image, t_last, generator)
+ t_last = t
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/score_sde_ve/__init__.py b/diffusers/pipelines/score_sde_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..000d61f6e9b183728cb6fc137e7180cac3a616df
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_score_sde_ve import ScoreSdeVePipeline
diff --git a/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb6a5d3cbd40aedfdc684f84d6b1c65fcfd3670
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -0,0 +1,101 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import ScoreSdeVeScheduler
+
+
+class ScoreSdeVePipeline(DiffusionPipeline):
+ r"""
+ Parameters:
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]):
+ The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image.
+ """
+ unet: UNet2DModel
+ scheduler: ScoreSdeVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 2000,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ self.scheduler.set_sigmas(num_inference_steps)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
+
+ # correction step
+ for _ in range(self.scheduler.config.correct_steps):
+ model_output = self.unet(sample, sigma_t).sample
+ sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
+
+ # prediction step
+ model_output = model(sample, sigma_t).sample
+ output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
+
+ sample, sample_mean = output.prev_sample, output.prev_sample_mean
+
+ sample = sample_mean.clamp(0, 1)
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ sample = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return ImagePipelineOutput(images=sample)
diff --git a/diffusers/pipelines/stable_diffusion/README.md b/diffusers/pipelines/stable_diffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..be4c5d942b2e313ebfac5acc22764de8bae48bf5
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/README.md
@@ -0,0 +1,176 @@
+# Stable Diffusion
+
+## Overview
+
+Stable Diffusion was proposed in [Stable Diffusion Announcement](https://stability.ai/blog/stable-diffusion-announcement) by Patrick Esser and Robin Rombach and the Stability AI team.
+
+The summary of the model is the following:
+
+*Stable Diffusion is a text-to-image model that will empower billions of people to create stunning art within seconds. It is a breakthrough in speed and quality meaning that it can run on consumer GPUs. You can see some of the amazing output that has been created by this model without pre or post-processing on this page. The model itself builds upon the work of the team at CompVis and Runway in their widely used latent diffusion model combined with insights from the conditional diffusion models by our lead generative AI developer Katherine Crowson, Dall-E 2 by Open AI, Imagen by Google Brain and many others. We are delighted that AI media generation is a cooperative field and hope it can continue this way to bring the gift of creativity to all.*
+
+## Tips:
+
+- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model.
+- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion).
+- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can
+download the weights with `git lfs install; git clone https://huggingface.co/runwayml/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below.
+- Stable Diffusion can work with a variety of different samplers as is shown below.
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
+| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
+
+## Examples:
+
+### Using Stable Diffusion without being logged into the Hub.
+
+If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+```
+
+This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-5"`:
+
+```
+git lfs install
+git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+and simply passing the local path to `from_pretrained`:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
+```
+
+### Text-to-Image with default PLMS scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).sample[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Text-to-Image with DDIM scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline, DDIMScheduler
+
+scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ scheduler=scheduler,
+).to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).sample[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### Text-to-Image with K-LMS scheduler
+
+```python
+# make sure you're logged in with `huggingface-cli login`
+from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
+
+lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ scheduler=lms,
+).to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).sample[0]
+
+image.save("astronaut_rides_horse.png")
+```
+
+### CycleDiffusion using Stable Diffusion and DDIM scheduler
+
+```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import CycleDiffusionPipeline, DDIMScheduler
+
+
+# load the scheduler. CycleDiffusion only supports stochastic schedulers.
+
+# load the pipeline
+# make sure you're logged in with `huggingface-cli login`
+model_id_or_path = "CompVis/stable-diffusion-v1-4"
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
+pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
+
+# let's download an initial image
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+init_image.save("horse.png")
+
+# let's specify a prompt
+source_prompt = "An astronaut riding a horse"
+prompt = "An astronaut riding an elephant"
+
+# call the pipeline
+image = pipe(
+ prompt=prompt,
+ source_prompt=source_prompt,
+ image=init_image,
+ num_inference_steps=100,
+ eta=0.1,
+ strength=0.8,
+ guidance_scale=2,
+ source_guidance_scale=1,
+).images[0]
+
+image.save("horse_to_elephant.png")
+
+# let's try another example
+# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+init_image.save("black.png")
+
+source_prompt = "A black colored car"
+prompt = "A blue colored car"
+
+# call the pipeline
+torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ source_prompt=source_prompt,
+ image=init_image,
+ num_inference_steps=100,
+ eta=0.1,
+ strength=0.85,
+ guidance_scale=3,
+ source_guidance_scale=1,
+).images[0]
+
+image.save("black_to_blue.png")
+```
diff --git a/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ac88e1f4ddd525830eefa9bcb7207554d20db6
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import (
+ BaseOutput,
+ is_flax_available,
+ is_onnx_available,
+ is_torch_available,
+ is_transformers_available,
+ is_transformers_version,
+)
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_cycle_diffusion import CycleDiffusionPipeline
+ from .pipeline_stable_diffusion import StableDiffusionPipeline
+ from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+ from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+ from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
+ from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
+ from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
+else:
+ from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
+
+if is_transformers_available() and is_onnx_available():
+ from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
+ from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
+
+if is_transformers_available() and is_flax_available():
+ import flax
+
+ @flax.struct.dataclass
+ class FlaxStableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`np.ndarray`)
+ Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content.
+ """
+
+ images: np.ndarray
+ nsfw_content_detected: List[bool]
+
+ from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
+ from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
+ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a688a52a7a6ec65a5774dd6c6fe1ce1e9d66acab
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -0,0 +1,687 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+
+ if prev_timestep <= 0:
+ return clean_latents
+
+ # 2. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
+ )
+
+ variance = scheduler._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ # direction pointing to x_t
+ e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
+ dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
+ noise = std_dev_t * torch.randn(
+ clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
+ )
+ prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
+
+ return prev_latents
+
+
+def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = (
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
+ )
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if scheduler.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = scheduler._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
+
+ noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
+ variance ** (0.5) * eta
+ )
+ return noise
+
+
+class CycleDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ image = image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timestep
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ clean_latents = init_latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents, clean_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ source_prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ source_guidance_scale: Optional[float] = 1,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.1,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ source_guidance_scale (`float`, *optional*, defaults to 1):
+ Guidance scale for the source prompt. This is useful to control the amount of influence the source
+ prompt for encoding.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.1):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
+ source_text_embeddings = self._encode_prompt(
+ source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
+ )
+
+ # 4. Preprocess image
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, clean_latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+ source_latents = latents
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ generator = extra_step_kwargs.pop("generator", None)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ source_latent_model_input = torch.cat([source_latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
+
+ # predict the noise residual
+ concat_latent_model_input = torch.stack(
+ [
+ source_latent_model_input[0],
+ latent_model_input[0],
+ source_latent_model_input[1],
+ latent_model_input[1],
+ ],
+ dim=0,
+ )
+ concat_text_embeddings = torch.stack(
+ [
+ source_text_embeddings[0],
+ text_embeddings[0],
+ source_text_embeddings[1],
+ text_embeddings[1],
+ ],
+ dim=0,
+ )
+ concat_noise_pred = self.unet(
+ concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
+ ).sample
+
+ # perform guidance
+ (
+ source_noise_pred_uncond,
+ noise_pred_uncond,
+ source_noise_pred_text,
+ noise_pred_text,
+ ) = concat_noise_pred.chunk(4, dim=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
+ source_noise_pred_text - source_noise_pred_uncond
+ )
+
+ # Sample source_latents from the posterior distribution.
+ prev_source_latents = posterior_sample(
+ self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
+ )
+ # Compute noise.
+ noise = compute_noise(
+ self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
+ )
+ source_latents = prev_source_latents
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
+ ).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..23148dcfe2e32d71e597305efb7018a758b39c76
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -0,0 +1,429 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from flax.jax_utils import unreplicate
+from flax.training.common_utils import shard
+from packaging import version
+from PIL import Image
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...pipeline_flax_utils import FlaxDiffusionPipeline
+from ...schedulers import (
+ FlaxDDIMScheduler,
+ FlaxDPMSolverMultistepScheduler,
+ FlaxLMSDiscreteScheduler,
+ FlaxPNDMScheduler,
+)
+from ...utils import deprecate, logging
+from . import FlaxStableDiffusionPipelineOutput
+from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
+ [`FlaxDPMSolverMultistepScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: FlaxAutoencoderKL,
+ text_encoder: FlaxCLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: FlaxUNet2DConditionModel,
+ scheduler: Union[
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
+ ],
+ safety_checker: FlaxStableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ dtype: jnp.dtype = jnp.float32,
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def prepare_inputs(self, prompt: Union[str, List[str]]):
+ if not isinstance(prompt, (str, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ return text_input.input_ids
+
+ def _get_has_nsfw_concepts(self, features, params):
+ has_nsfw_concepts = self.safety_checker(features, params)
+ return has_nsfw_concepts
+
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
+ # safety_model_params should already be replicated when jit is True
+ pil_images = [Image.fromarray(image) for image in images]
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
+
+ if jit:
+ features = shard(features)
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
+ safety_model_params = unreplicate(safety_model_params)
+ else:
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
+
+ images_was_copied = False
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ if not images_was_copied:
+ images_was_copied = True
+ images = images.copy()
+
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
+
+ if any(has_nsfw_concepts):
+ warnings.warn(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead. Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ def _generate(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.PRNGKey,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: float = 7.5,
+ latents: Optional[jnp.array] = None,
+ debug: bool = False,
+ neg_prompt_ids: jnp.array = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+ batch_size = prompt_ids.shape[0]
+
+ max_length = prompt_ids.shape[-1]
+
+ if neg_prompt_ids is None:
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ ).input_ids
+ else:
+ uncond_input = neg_prompt_ids
+ uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
+ context = jnp.concatenate([uncond_embeddings, text_embeddings])
+
+ latents_shape = (
+ batch_size,
+ self.unet.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if latents is None:
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(
+ params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+ )
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ if debug:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+ return image
+
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.PRNGKey,
+ num_inference_steps: int = 50,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: float = 7.5,
+ latents: jnp.array = None,
+ return_dict: bool = True,
+ jit: bool = False,
+ debug: bool = False,
+ neg_prompt_ids: jnp.array = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ jit (`bool`, defaults to `False`):
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ if jit:
+ images = _p_generate(
+ self,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ debug,
+ neg_prompt_ids,
+ )
+ else:
+ images = self._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ debug,
+ neg_prompt_ids,
+ )
+
+ if self.safety_checker is not None:
+ safety_params = params["safety_checker"]
+ images_uint8_casted = (images * 255).round().astype("uint8")
+ num_devices, batch_size = images.shape[:2]
+
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
+ images = np.asarray(images)
+
+ # block images
+ if any(has_nsfw_concept):
+ for i, is_nsfw in enumerate(has_nsfw_concept):
+ if is_nsfw:
+ images[i] = np.asarray(images_uint8_casted[i])
+
+ images = images.reshape(num_devices, batch_size, height, width, 3)
+ else:
+ images = np.asarray(images)
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
+
+
+# TODO: maybe use a config dict instead of so many static argnums
+@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
+def _p_generate(
+ pipe,
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ debug,
+ neg_prompt_ids,
+):
+ return pipe._generate(
+ prompt_ids,
+ params,
+ prng_seed,
+ num_inference_steps,
+ height,
+ width,
+ guidance_scale,
+ latents,
+ debug,
+ neg_prompt_ids,
+ )
+
+
+@partial(jax.pmap, static_broadcasted_argnums=(0,))
+def _p_get_has_nsfw_concepts(pipe, features, params):
+ return pipe._get_has_nsfw_concepts(features, params)
+
+
+def unshard(x: jnp.ndarray):
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
+ num_devices, batch_size = x.shape[:2]
+ rest = x.shape[2:]
+ return x.reshape(num_devices * batch_size, *rest)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9a8ff724a459b8204b9a4e6351c1bd3e45964f
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -0,0 +1,353 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__)
+
+
+class OnnxStableDiffusionPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+ text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+ uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ latents: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if generator is None:
+ generator = np.random
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ text_embeddings = self._encode_prompt(
+ prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # get the initial random noise unless the user supplied it
+ latents_dtype = text_embeddings.dtype
+ latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+ if latents is None:
+ latents = generator.randn(*latents_shape).astype(latents_dtype)
+ elif latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ latents = latents * np.float(self.scheduler.init_noise_sigma)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+
+ image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)
+
+ # There will throw an error if use safety_checker batchsize>1
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
+ deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..2242d21b1d9147b61181cd43c59649dbafbdc598
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -0,0 +1,459 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+ text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+ uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[np.ndarray, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ text_embeddings = self._encode_prompt(
+ prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ latents_dtype = text_embeddings.dtype
+ image = image.astype(latents_dtype)
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = len(prompt) // init_latents.shape[0]
+ init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0)
+ elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
+ )
+ else:
+ init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
+
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
+ timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+
+ # add noise to latents using the timesteps
+ noise = generator.randn(*init_latents.shape).astype(latents_dtype)
+ init_latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
+ )
+ init_latents = init_latents.numpy()
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].numpy()
+
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # safety_checker does not support batched inputs yet
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..930d61de99ccc2d9fb13a8fd4a0dbf408f0f9c63
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -0,0 +1,478 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+NUM_UNET_INPUT_CHANNELS = 9
+NUM_LATENT_CHANNELS = 4
+
+
+def prepare_mask_and_masked_image(image, mask, latents_shape):
+ image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = image.astype(np.float32) / 127.5 - 1.0
+
+ image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
+ masked_image = image * (image_mask < 127.5)
+
+ mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ return mask, masked_image
+
+
+class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+ text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+ uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: PIL.Image.Image,
+ mask_image: PIL.Image.Image,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ latents: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ text_embeddings = self._encode_prompt(
+ prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ num_channels_latents = NUM_LATENT_CHANNELS
+ latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ latents = generator.randn(*latents_shape).astype(latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # prepare mask and masked_image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
+ mask = mask.astype(latents.dtype)
+ masked_image = masked_image.astype(latents.dtype)
+
+ masked_image_latents = self.vae_encoder(sample=masked_image)[0]
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt
+ mask = mask.repeat(batch_size * num_images_per_prompt, 0)
+ masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
+
+ mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+
+ unet_input_channels = NUM_UNET_INPUT_CHANNELS
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
+ raise ValueError(
+ "Incorrect configuration settings! The config of `pipeline.unet` expects"
+ f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * np.float(self.scheduler.init_noise_sigma)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ # concat latents, mask, masked_image_latnets in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.cpu().numpy()
+ latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
+
+ # predict the noise residual
+ timestep = np.array([t], dtype=timestep_dtype)
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # safety_checker does not support batched inputs yet
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e85e51cca21d5bdaead87e77fc184a65d9e9ab
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -0,0 +1,461 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...onnx_utils import OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
+from . import StableDiffusionPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
+ mask = 1 - mask # repaint white, keep black
+ return mask
+
+
+class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to
+ provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+ text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+ uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[np.ndarray, PIL.Image.Image],
+ mask_image: Union[np.ndarray, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[np.random.RandomState] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`nd.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`nd.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`np.random.RandomState`, *optional*):
+ A np.random.RandomState to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if generator is None:
+ generator = np.random
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ text_embeddings = self._encode_prompt(
+ prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ latents_dtype = text_embeddings.dtype
+ image = image.astype(latents_dtype)
+
+ # encode the init image into latents and scale the latents
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
+ init_latents_orig = init_latents
+
+ # preprocess mask
+ if not isinstance(mask_image, np.ndarray):
+ mask_image = preprocess_mask(mask_image, 8)
+ mask_image = mask_image.astype(latents_dtype)
+ mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
+
+ # check sizes
+ if not mask.shape == init_latents.shape:
+ raise ValueError("The mask and image should be the same size!")
+
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
+ timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
+
+ # add noise to latents using the timesteps
+ noise = generator.randn(*init_latents.shape).astype(latents_dtype)
+ init_latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
+ )
+ init_latents = init_latents.numpy()
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].numpy()
+
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ ).prev_sample
+
+ latents = latents.numpy()
+
+ init_latents_proper = self.scheduler.add_noise(
+ torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t]))
+ )
+
+ init_latents_proper = init_latents_proper.numpy()
+
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # There will throw an error if use safety_checker batchsize>1
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a8703f3ea4070337e5f55be5199277c00413ab
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,578 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77e71653078dfb206f267f889334d1ed7b7da8b
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -0,0 +1,461 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionImageVariationPipeline(DiffusionPipeline):
+ r"""
+ Pipeline to generate variations from an input image using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warn(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ uncond_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
+ configuration of
+ [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+ `CLIPFeatureExtractor`
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(image, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..933f59c3bd3291e9445e26b707d16c0e25c5ff67
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -0,0 +1,608 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ image = image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc416f57d3e0ee09331b763ef91c01acb3ae4e57
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,725 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (ot the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ if isinstance(image, PIL.Image.Image):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ if isinstance(mask, PIL.Image.Image):
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+class StableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ mask = mask.repeat(batch_size, 1, 1, 1)
+ masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess mask and image
+ if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image):
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+ else:
+ mask = mask_image
+ masked_image = image * (mask < 0.5)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 11. Post-processing
+ image = self.decode_latents(latents)
+
+ # 12. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 13. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..60d52eaa1ab4bc380e282067db6bf624589289cd
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -0,0 +1,623 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+ image = image.to(device=self.device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+ init_latents_orig = init_latents
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+ return latents, init_latents_orig, noise
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
+ that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
+
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image and mask
+ if not isinstance(image, torch.FloatTensor):
+ image = preprocess_image(image)
+
+ if not isinstance(mask_image, torch.FloatTensor):
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ # encode the init image into latents and scale the latents
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+
+ # 7. Prepare mask latent
+ mask = mask_image.to(device=self.device, dtype=latents.dtype)
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 10. Post-processing
+ image = self.decode_latents(latents)
+
+ # 11. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 12. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..72981aebe18478b320a7d397924b925c6dd6ef5e
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -0,0 +1,535 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ # resize to multiple of 64
+ width, height = image.size
+ width = width - width % 64
+ height = height - height % 64
+ image = image.resize((width, height))
+
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ return image
+
+
+class StableDiffusionUpscalePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image super-resolution using Stable Diffusion 2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ low_res_scheduler ([`SchedulerMixin`]):
+ A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
+ [`DDPMScheduler`].
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ max_noise_level: int = 350,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ )
+ self.register_to_config(max_noise_level=max_noise_level)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.unet.config.attention_head_dim)
+
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
+ def decode_latents(self, latents):
+ latents = 1 / 0.08333 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(self, prompt, image, noise_level, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ # verify batch size of prompt and image are same if image is a list or tensor
+ if isinstance(image, list) or isinstance(image, torch.Tensor):
+ if isinstance(prompt, str):
+ batch_size = 1
+ else:
+ batch_size = len(prompt)
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ else:
+ image_batch_size = image.shape[0]
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
+ " Please make sure that passed `prompt` matches the batch size of `image`."
+ )
+
+ # check noise level
+ if noise_level > self.config.max_noise_level:
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height, width)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 20,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+ `Image`, or tensor representing an image batch which will be upscaled. *
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 1. Check inputs
+ self.check_inputs(prompt, image, noise_level, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image
+ image = [image] if isinstance(image, PIL.Image.Image) else image
+ if isinstance(image, list):
+ image = [preprocess(img) for img in image]
+ image = torch.cat(image, dim=0)
+ image = image.to(dtype=text_embeddings.dtype, device=device)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Add noise to image
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
+ else:
+ noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
+ image = self.low_res_scheduler.add_noise(image, noise, noise_level)
+ image = torch.cat([image] * 2) if do_classifier_free_guidance else image
+ noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level
+
+ # 6. Prepare latent variables
+ height, width = image.shape[2:]
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Check that sizes of image and latents match
+ num_channels_image = image.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = torch.cat([latent_model_input, image], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 10. Post-processing
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ self.vae.to(dtype=torch.float32)
+ image = self.decode_latents(latents.float())
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..1476c1ede62c6f2189c9025598ddab02169c5f69
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,124 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+ self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+ # increase this value to create a stronger `nfsw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ # has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+ has_nsfw_concepts = [False]
+
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ if has_nsfw_concept:
+ images[idx] = np.zeros(images[idx].shape) # black image
+
+ if any(has_nsfw_concepts):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ " Try again with a different prompt and/or seed."
+ )
+
+ return images, has_nsfw_concepts
+
+ @torch.no_grad()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
diff --git a/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f669d22b76a44a5fbd523e6cbc61167cb12332
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -0,0 +1,112 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from flax.core.frozen_dict import FrozenDict
+from transformers import CLIPConfig, FlaxPreTrainedModel
+from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
+
+
+def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
+ norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
+ norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
+ return jnp.matmul(norm_emb_1, norm_emb_2.T)
+
+
+class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
+ config: CLIPConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
+ self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
+
+ self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
+ self.special_care_embeds = self.param(
+ "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
+ )
+
+ self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
+ self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))
+
+ def __call__(self, clip_input):
+ pooled_output = self.vision_model(clip_input)[1]
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nfsw` filter
+ # at the cost of increasing the possibility of filtering benign image inputs
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+ special_scores = jnp.round(special_scores, 3)
+ is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+ # Use a lower threshold if an image has any special care concept
+ special_adjustment = is_special_care * 0.01
+
+ concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+ concept_scores = jnp.round(concept_scores, 3)
+ has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+ return has_nsfw_concepts
+
+
+class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
+ config_class = CLIPConfig
+ main_input_name = "clip_input"
+ module_class = FlaxStableDiffusionSafetyCheckerModule
+
+ def __init__(
+ self,
+ config: CLIPConfig,
+ input_shape: Optional[Tuple] = None,
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ if input_shape is None:
+ input_shape = (1, 224, 224, 3)
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensor
+ clip_input = jax.random.normal(rng, input_shape)
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(rngs, clip_input)["params"]
+
+ return random_params
+
+ def __call__(
+ self,
+ clip_input,
+ params: dict = None,
+ ):
+ clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
+
+ return self.module.apply(
+ {"params": params or self.params},
+ jnp.array(clip_input, dtype=jnp.float32),
+ rngs={},
+ )
diff --git a/diffusers/pipelines/stable_diffusion_safe/__init__.py b/diffusers/pipelines/stable_diffusion_safe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ff61fa3b5429dd40a76b4f9a10c31ddee62967
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/__init__.py
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+class SafetyConfig(object):
+ WEAK = {
+ "sld_warmup_steps": 15,
+ "sld_guidance_scale": 20,
+ "sld_threshold": 0.0,
+ "sld_momentum_scale": 0.0,
+ "sld_mom_beta": 0.0,
+ }
+ MEDIUM = {
+ "sld_warmup_steps": 10,
+ "sld_guidance_scale": 1000,
+ "sld_threshold": 0.01,
+ "sld_momentum_scale": 0.3,
+ "sld_mom_beta": 0.4,
+ }
+ STRONG = {
+ "sld_warmup_steps": 7,
+ "sld_guidance_scale": 2000,
+ "sld_threshold": 0.025,
+ "sld_momentum_scale": 0.5,
+ "sld_mom_beta": 0.7,
+ }
+ MAX = {
+ "sld_warmup_steps": 0,
+ "sld_guidance_scale": 5000,
+ "sld_threshold": 1.0,
+ "sld_momentum_scale": 0.5,
+ "sld_mom_beta": 0.7,
+ }
+
+
+@dataclass
+class StableDiffusionSafePipelineOutput(BaseOutput):
+ """
+ Output class for Safe Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work"
+ (nsfw) content, or `None` if no safety check was performed or no images were flagged.
+ applied_safety_concept (`str`)
+ The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+ unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
+ applied_safety_concept: Optional[str]
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .safety_checker import SafeStableDiffusionSafetyChecker
diff --git a/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cb0f2c03daf1ca284c5a57b928de9f922b621c5
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -0,0 +1,746 @@
+import inspect
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import deprecate, is_accelerate_available, logging
+from . import StableDiffusionSafePipelineOutput
+from .safety_checker import SafeStableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionPipelineSafe(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Safe Latent Diffusion.
+
+ The implementation is based on the [`StableDiffusionPipeline`]
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ ],
+ safety_checker: SafeStableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ safety_concept: Optional[str] = (
+ "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
+ " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
+ " abuse, brutality, cruelty"
+ )
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self._safety_text_concept = safety_concept
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ @property
+ def safety_concept(self):
+ r"""
+ Getter method for the safety concept used with SLD
+
+ Returns:
+ `str`: The text describing the safety concept
+ """
+ return self._safety_text_concept
+
+ @safety_concept.setter
+ def safety_concept(self, concept):
+ r"""
+ Setter method for the safety concept used with SLD
+
+ Args:
+ concept (`str`):
+ The text of the new safety concept
+ """
+ self._safety_text_concept = concept
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ enable_safety_guidance,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # Encode the safety concept text
+ if enable_safety_guidance:
+ safety_concept_input = self.tokenizer(
+ [self._safety_text_concept],
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0]
+
+ # duplicate safety embeddings for each generation per prompt, using mps friendly method
+ seq_len = safety_embeddings.shape[1]
+ safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance + sld, we need to do three forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing three forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
+
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ flagged_images = None
+ if any(has_nsfw_concept):
+ logger.warning(
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
+ " instead."
+ f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
+ )
+ flagged_images = np.zeros((2, *image.shape[1:]))
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
+ if has_nsfw_concept:
+ flagged_images[idx] = image[idx]
+ image[idx] = np.zeros(image[idx].shape) # black image
+ else:
+ has_nsfw_concept = None
+ flagged_images = None
+ return image, has_nsfw_concept, flagged_images
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def perform_safety_guidance(
+ self,
+ enable_safety_guidance,
+ safety_momentum,
+ noise_guidance,
+ noise_pred_out,
+ i,
+ sld_guidance_scale,
+ sld_warmup_steps,
+ sld_threshold,
+ sld_momentum_scale,
+ sld_mom_beta,
+ ):
+ # Perform SLD guidance
+ if enable_safety_guidance:
+ if safety_momentum is None:
+ safety_momentum = torch.zeros_like(noise_guidance)
+ noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1]
+ noise_pred_safety_concept = noise_pred_out[2]
+
+ # Equation 6
+ scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
+
+ # Equation 6
+ safety_concept_scale = torch.where(
+ (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
+ )
+
+ # Equation 4
+ noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
+
+ # Equation 7
+ noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
+
+ # Equation 8
+ safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
+
+ if i >= sld_warmup_steps: # Warmup
+ # Equation 3
+ noise_guidance = noise_guidance - noise_guidance_safety
+ return noise_guidance, safety_momentum
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ sld_guidance_scale: Optional[float] = 1000,
+ sld_warmup_steps: Optional[int] = 10,
+ sld_threshold: Optional[float] = 0.01,
+ sld_momentum_scale: Optional[float] = 0.3,
+ sld_mom_beta: Optional[float] = 0.4,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ sld_guidance_scale (`float`, *optional*, defaults to 1000):
+ Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
+ `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be
+ disabled.
+ sld_warmup_steps (`int`, *optional*, defaults to 10):
+ Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
+ `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_threshold (`float`, *optional*, defaults to 0.01):
+ Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
+ is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_momentum_scale (`float`, *optional*, defaults to 0.3):
+ Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
+ momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+ than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ sld_mom_beta (`float`, *optional*, defaults to 0.4):
+ Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
+ momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
+ than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent
+ Diffusion](https://arxiv.org/abs/2211.05105).
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
+ if not enable_safety_guidance:
+ warnings.warn("Safety checker disabled!")
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ safety_momentum = None
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * (3 if enable_safety_guidance else 2))
+ if do_classifier_free_guidance
+ else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
+ noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
+
+ # default classifier free guidance
+ noise_guidance = noise_pred_text - noise_pred_uncond
+
+ # Perform SLD guidance
+ if enable_safety_guidance:
+ if safety_momentum is None:
+ safety_momentum = torch.zeros_like(noise_guidance)
+ noise_pred_safety_concept = noise_pred_out[2]
+
+ # Equation 6
+ scale = torch.clamp(
+ torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
+ )
+
+ # Equation 6
+ safety_concept_scale = torch.where(
+ (noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
+ torch.zeros_like(scale),
+ scale,
+ )
+
+ # Equation 4
+ noise_guidance_safety = torch.mul(
+ (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
+ )
+
+ # Equation 7
+ noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
+
+ # Equation 8
+ safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
+
+ if i >= sld_warmup_steps: # Warmup
+ # Equation 3
+ noise_guidance = noise_guidance - noise_guidance_safety
+
+ noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept, flagged_images = self.run_safety_checker(
+ image, device, text_embeddings.dtype, enable_safety_guidance
+ )
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+ if flagged_images is not None:
+ flagged_images = self.numpy_to_pil(flagged_images)
+
+ if not return_dict:
+ return (
+ image,
+ has_nsfw_concept,
+ self._safety_text_concept if enable_safety_guidance else None,
+ flagged_images,
+ )
+
+ return StableDiffusionSafePipelineOutput(
+ images=image,
+ nsfw_content_detected=has_nsfw_concept,
+ applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
+ unsafe_images=flagged_images,
+ )
diff --git a/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9dbf51e86440847646e168e5a50ebf835440f2a
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
@@ -0,0 +1,110 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class SafeStableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ _no_split_modules = ["CLIPEncoderLayer"]
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+ self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+ # increase this value to create a stronger `nfsw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ return images, has_nsfw_concepts
+
+ @torch.no_grad()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ return images, has_nsfw_concepts
diff --git a/diffusers/pipelines/stochastic_karras_ve/__init__.py b/diffusers/pipelines/stochastic_karras_ve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2582043781130794e01b96b3e6beecbfe9f369
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_stochastic_karras_ve import KarrasVePipeline
diff --git a/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..739de8ebe620b5c99168720340a2485fa61d5a06
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,129 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import KarrasVeScheduler
+
+
+class KarrasVePipeline(DiffusionPipeline):
+ r"""
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`KarrasVeScheduler`]):
+ Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ # add type hints for linting
+ unet: UNet2DModel
+ scheduler: KarrasVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ # sample x_0 ~ N(0, sigma_0^2 * I)
+ sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # here sigma_t == t_i from the paper
+ sigma = self.scheduler.schedule[t]
+ sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+ # 1. Select temporarily increased noise level sigma_hat
+ # 2. Add new noise to move from sample_i to sample_hat
+ sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+ # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
+
+ # 4. Evaluate dx/dt at sigma_hat
+ # 5. Take Euler step from sigma to sigma_prev
+ step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+ if sigma_prev != 0:
+ # 6. Apply 2nd order correction
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
+ step_output = self.scheduler.step_correct(
+ model_output,
+ sigma_hat,
+ sigma_prev,
+ sample_hat,
+ step_output.prev_sample,
+ step_output["derivative"],
+ )
+ sample = step_output.prev_sample
+
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/versatile_diffusion/__init__.py b/diffusers/pipelines/versatile_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d2caa7e2399001632b61504aa7bc59a1ad2bcfe
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/__init__.py
@@ -0,0 +1,16 @@
+from ...utils import is_torch_available, is_transformers_available, is_transformers_version
+
+
+if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
+ from .modeling_text_unet import UNetFlatConditionModel
+ from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
+ from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
+ from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
+ from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
+else:
+ from ...utils.dummy_torch_and_transformers_objects import (
+ VersatileDiffusionDualGuidedPipeline,
+ VersatileDiffusionImageVariationPipeline,
+ VersatileDiffusionPipeline,
+ VersatileDiffusionTextToImagePipeline,
+ )
diff --git a/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a3d4c55e9199f448ccc820dc459131838ff299
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -0,0 +1,1118 @@
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
+from ...models.attention import DualTransformer2DModel, Transformer2DModel
+from ...models.embeddings import TimestepEmbedding, Timesteps
+from ...models.unet_2d_condition import UNet2DConditionOutput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlockFlat":
+ return DownBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "CrossAttnDownBlockFlat":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
+ return CrossAttnDownBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ raise ValueError(f"{down_block_type} is not supported.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlockFlat":
+ return UpBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ )
+ elif up_block_type == "CrossAttnUpBlockFlat":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
+ return CrossAttnUpBlockFlat(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ raise ValueError(f"{up_block_type} is not supported.")
+
+
+# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
+class UNetFlatConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
+ timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the models (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockFlat",
+ "CrossAttnDownBlockFlat",
+ "CrossAttnDownBlockFlat",
+ "DownBlockFlat",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ "CrossAttnUpBlockFlat",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ num_class_embeds: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockFlatCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.config.attention_head_dim
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.config.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+
+class LinearMultiDim(nn.Linear):
+ def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
+ in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)
+ if out_features is None:
+ out_features = in_features
+ out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)
+ self.in_features_multidim = in_features
+ self.out_features_multidim = out_features
+ super().__init__(np.array(in_features).prod(), np.array(out_features).prod())
+
+ def forward(self, input_tensor, *args, **kwargs):
+ shape = input_tensor.shape
+ n_dim = len(self.in_features_multidim)
+ input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)
+ output_tensor = super().forward(input_tensor)
+ output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)
+ return output_tensor
+
+
+class ResnetBlockFlat(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ time_embedding_norm="default",
+ use_in_shortcut=None,
+ second_dim=4,
+ **kwargs,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+
+ in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
+ self.in_channels_prod = np.array(in_channels).prod()
+ self.channels_multidim = in_channels
+
+ if out_channels is not None:
+ out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
+ out_channels_prod = np.array(out_channels).prod()
+ self.out_channels_multidim = out_channels
+ else:
+ out_channels_prod = self.in_channels_prod
+ self.out_channels_multidim = self.channels_multidim
+ self.time_embedding_norm = time_embedding_norm
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
+ self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)
+
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = (
+ self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
+ )
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, input_tensor, temb):
+ shape = input_tensor.shape
+ n_dim = len(self.channels_multidim)
+ input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
+ input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)
+
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = input_tensor + hidden_states
+
+ output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
+ output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)
+
+ return output_tensor
+
+
+# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
+class DownBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ LinearMultiDim(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
+class CrossAttnDownBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ LinearMultiDim(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
+class UpBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
+class CrossAttnUpBlockFlat(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
+class UNetMidBlockFlatCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlockFlat(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ head_dims = self.attn_num_head_channels
+ head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+ if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a common divisor of "
+ f"the number of heads used in cross_attention: {head_dims}"
+ )
+ if slice_size is not None and slice_size > min(head_dims):
+ raise ValueError(
+ f"slice_size {slice_size} has to be smaller or equal to "
+ f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be7f4d3aee6df82daaf515bdd412c4b57b009c4
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -0,0 +1,463 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+
+import PIL.Image
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import logging
+from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
+from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
+from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionMegaSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPFeatureExtractor
+ text_encoder: CLIPTextModel
+ image_encoder: CLIPVisionModel
+ image_unet: UNet2DConditionModel
+ text_unet: UNet2DConditionModel
+ vae: AutoencoderKL
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ image_feature_extractor: CLIPFeatureExtractor,
+ text_encoder: CLIPTextModel,
+ image_encoder: CLIPVisionModel,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ image_feature_extractor=image_feature_extractor,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.image_unet.config.attention_head_dim // 2
+ self.image_unet.set_attention_slice(slice_size)
+ self.text_unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def image_variation(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
+ The image prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe.image_variation(image, generator=generator).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ return VersatileDiffusionImageVariationPipeline(**components)(
+ image=image,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def text_to_image(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
+ >>> image.save("./astronaut.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
+ output = temp_pipeline(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+ # swap the attention blocks back to the original state
+ temp_pipeline._swap_unet_attention_blocks()
+
+ return output
+
+ @torch.no_grad()
+ def dual_guided(
+ self,
+ prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ image: Union[str, List[str]],
+ text_to_image_strength: float = 0.5,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> text = "a red car in the sun"
+
+ >>> pipe = VersatileDiffusionPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> text_to_image_strength = 0.75
+
+ >>> image = pipe.dual_guided(
+ ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
+ ... ).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys()
+ components = {name: component for name, component in self.components.items() if name in expected_components}
+ temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
+ output = temp_pipeline(
+ prompt=prompt,
+ image=image,
+ text_to_image_strength=text_to_image_strength,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+ temp_pipeline._revert_dual_attention()
+
+ return output
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a90ae2c7620aae03d32604eefae7b8f1f2b028f
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -0,0 +1,621 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+
+import PIL
+from transformers import (
+ CLIPFeatureExtractor,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention import DualTransformer2DModel, Transformer2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_accelerate_available, logging
+from .modeling_text_unet import UNetFlatConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPFeatureExtractor
+ text_encoder: CLIPTextModelWithProjection
+ image_encoder: CLIPVisionModelWithProjection
+ image_unet: UNet2DConditionModel
+ text_unet: UNetFlatConditionModel
+ vae: AutoencoderKL
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+ _optional_components = ["text_unet"]
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ image_feature_extractor: CLIPFeatureExtractor,
+ text_encoder: CLIPTextModelWithProjection,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNetFlatConditionModel,
+ vae: AutoencoderKL,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ image_feature_extractor=image_feature_extractor,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ if self.text_unet is not None and (
+ "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
+ ):
+ # if loading from a universal checkpoint rather than a saved dual-guided pipeline
+ self._convert_to_dual_attention()
+
+ def remove_unused_weights(self):
+ self.register_modules(text_unet=None)
+
+ def _convert_to_dual_attention(self):
+ """
+ Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks
+ from both `image_unet` and `text_unet`
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, Transformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+
+ image_transformer = self.image_unet.get_submodule(parent_name)[index]
+ text_transformer = self.text_unet.get_submodule(parent_name)[index]
+
+ config = image_transformer.config
+ dual_transformer = DualTransformer2DModel(
+ num_attention_heads=config.num_attention_heads,
+ attention_head_dim=config.attention_head_dim,
+ in_channels=config.in_channels,
+ num_layers=config.num_layers,
+ dropout=config.dropout,
+ norm_num_groups=config.norm_num_groups,
+ cross_attention_dim=config.cross_attention_dim,
+ attention_bias=config.attention_bias,
+ sample_size=config.sample_size,
+ num_vector_embeds=config.num_vector_embeds,
+ activation_fn=config.activation_fn,
+ num_embeds_ada_norm=config.num_embeds_ada_norm,
+ )
+ dual_transformer.transformers[0] = image_transformer
+ dual_transformer.transformers[1] = text_transformer
+
+ self.image_unet.get_submodule(parent_name)[index] = dual_transformer
+ self.image_unet.register_to_config(dual_cross_attention=True)
+
+ def _revert_dual_attention(self):
+ """
+ Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call
+ this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline`
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, DualTransformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+ self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
+
+ self.image_unet.register_to_config(dual_cross_attention=False)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.image_unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.image_unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.image_unet.config.attention_head_dim)
+
+ self.image_unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+ return self.device
+ for module in self.image_unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+ embeds_pooled = encoder_output.text_embeds
+ embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = normalize_embeddings(text_embeddings)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = normalize_embeddings(uncond_embeddings)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+ embeds = self.image_encoder.visual_projection(embeds)
+ embeds_pooled = embeds[:, 0:1]
+ embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
+ pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
+ image_embeddings = self.image_encoder(pixel_values)
+ image_embeddings = normalize_embeddings(image_embeddings)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
+ uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
+ pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
+ uncond_embeddings = self.image_encoder(pixel_values)
+ uncond_embeddings = normalize_embeddings(uncond_embeddings)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and conditional embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, image, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}")
+ if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
+ raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")):
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, DualTransformer2DModel):
+ module.mix_ratio = mix_ratio
+
+ for i, type in enumerate(condition_types):
+ if type == "text":
+ module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings
+ module.transformer_index_for_condition[i] = 1 # use the second (text) transformer
+ else:
+ module.condition_lengths[i] = 257
+ module.transformer_index_for_condition[i] = 0 # use the first (image) transformer
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ image: Union[str, List[str]],
+ text_to_image_strength: float = 0.5,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionDualGuidedPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> text = "a red car in the sun"
+
+ >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe.remove_unused_weights()
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> text_to_image_strength = 0.75
+
+ >>> image = pipe(
+ ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
+ ... ).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ prompt = [prompt] if not isinstance(prompt, list) else prompt
+ image = [image] if not isinstance(image, list) else image
+ batch_size = len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompts
+ text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
+ image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
+ dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
+ prompt_types = ("text", "image")
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ dual_prompt_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Combine the attention blocks of the image and text UNets
+ self.set_transformer_params(text_to_image_strength, prompt_types)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b68dd244ce47d8e0d5528b813d5e14d398e2bf1a
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -0,0 +1,451 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ image_feature_extractor: CLIPFeatureExtractor
+ image_encoder: CLIPVisionModelWithProjection
+ image_unet: UNet2DConditionModel
+ vae: AutoencoderKL
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+ def __init__(
+ self,
+ image_feature_extractor: CLIPFeatureExtractor,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ self.register_modules(
+ image_feature_extractor=image_feature_extractor,
+ image_encoder=image_encoder,
+ image_unet=image_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.image_unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.image_unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.image_unet.config.attention_head_dim)
+
+ self.image_unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+ return self.device
+ for module in self.image_unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
+ embeds = self.image_encoder.visual_projection(embeds)
+ embeds_pooled = embeds[:, 0:1]
+ embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
+ pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
+ image_embeddings = self.image_encoder(pixel_values)
+ image_embeddings = normalize_embeddings(image_embeddings)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_images: List[str]
+ if negative_prompt is None:
+ uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, PIL.Image.Image):
+ uncond_images = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_images = negative_prompt
+
+ uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
+ pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
+ uncond_embeddings = self.image_encoder(pixel_values)
+ uncond_embeddings = normalize_embeddings(uncond_embeddings)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and conditional embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, image, height, width, callback_steps):
+ if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
+ raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
+ The image prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionImageVariationPipeline
+ >>> import torch
+ >>> import requests
+ >>> from io import BytesIO
+ >>> from PIL import Image
+
+ >>> # let's download an initial image
+ >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
+
+ >>> response = requests.get(url)
+ >>> image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe(image, generator=generator).images[0]
+ >>> image.save("./car_variation.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ image_embeddings = self._encode_prompt(
+ image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9c4bb7dc40e1cb207ce591d65440efe88adc1dd
--- /dev/null
+++ b/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -0,0 +1,505 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+import torch.utils.checkpoint
+
+from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention import Transformer2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_accelerate_available, logging
+from .modeling_text_unet import UNetFlatConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+ tokenizer: CLIPTokenizer
+ image_feature_extractor: CLIPFeatureExtractor
+ text_encoder: CLIPTextModelWithProjection
+ image_unet: UNet2DConditionModel
+ text_unet: UNetFlatConditionModel
+ vae: AutoencoderKL
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+ _optional_components = ["text_unet"]
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModelWithProjection,
+ image_unet: UNet2DConditionModel,
+ text_unet: UNetFlatConditionModel,
+ vae: AutoencoderKL,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ image_unet=image_unet,
+ text_unet=text_unet,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ if self.text_unet is not None:
+ self._swap_unet_attention_blocks()
+
+ def _swap_unet_attention_blocks(self):
+ """
+ Swap the `Transformer2DModel` blocks between the image and text UNets
+ """
+ for name, module in self.image_unet.named_modules():
+ if isinstance(module, Transformer2DModel):
+ parent_name, index = name.rsplit(".", 1)
+ index = int(index)
+ self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
+ self.text_unet.get_submodule(parent_name)[index],
+ self.image_unet.get_submodule(parent_name)[index],
+ )
+
+ def remove_unused_weights(self):
+ self.register_modules(text_unet=None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ if isinstance(self.image_unet.config.attention_head_dim, int):
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.image_unet.config.attention_head_dim // 2
+ else:
+ # if `attention_head_dim` is a list, take the smallest head size
+ slice_size = min(self.image_unet.config.attention_head_dim)
+
+ self.image_unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
+ return self.device
+ for module in self.image_unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+
+ def normalize_embeddings(encoder_output):
+ embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+ embeds_pooled = encoder_output.text_embeds
+ embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return embeds
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = normalize_embeddings(text_embeddings)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = normalize_embeddings(uncond_embeddings)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import VersatileDiffusionTextToImagePipeline
+ >>> import torch
+
+ >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
+ ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+ ... )
+ >>> pipe.remove_unused_weights()
+ >>> pipe = pipe.to("cuda")
+
+ >>> generator = torch.Generator(device="cuda").manual_seed(0)
+ >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
+ >>> image.save("./astronaut.png")
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.image_unet.config.sample_size * self.vae_scale_factor
+ width = width or self.image_unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.image_unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/vq_diffusion/__init__.py b/diffusers/pipelines/vq_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9f14f000648347fe75a5bec0cb45d08c7d2ff9
--- /dev/null
+++ b/diffusers/pipelines/vq_diffusion/__init__.py
@@ -0,0 +1,5 @@
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+ from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
diff --git a/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..333599d7ecf8b68827bdde55a37fa96c213c013a
--- /dev/null
+++ b/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -0,0 +1,335 @@
+# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+
+from diffusers import Transformer2DModel, VQModel
+from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+ """
+ Utility class for storing learned text embeddings for classifier free sampling
+ """
+
+ @register_to_config
+ def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+ super().__init__()
+
+ self.learnable = learnable
+
+ if self.learnable:
+ assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+ assert length is not None, "learnable=True requires `length` to be set"
+
+ embeddings = torch.zeros(length, hidden_size)
+ else:
+ embeddings = None
+
+ self.embeddings = torch.nn.Parameter(embeddings)
+
+
+class VQDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using VQ Diffusion
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vqvae ([`VQModel`]):
+ Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
+ representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. VQ Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ transformer ([`Transformer2DModel`]):
+ Conditional transformer to denoise the encoded image latents.
+ scheduler ([`VQDiffusionScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ vqvae: VQModel
+ text_encoder: CLIPTextModel
+ tokenizer: CLIPTokenizer
+ transformer: Transformer2DModel
+ learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
+ scheduler: VQDiffusionScheduler
+
+ def __init__(
+ self,
+ vqvae: VQModel,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ transformer: Transformer2DModel,
+ scheduler: VQDiffusionScheduler,
+ learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vqvae=vqvae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+ )
+
+ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+ # While CLIP does normalize the pooled output of the text transformer when combining
+ # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+ #
+ # CLIP normalizing the pooled output.
+ # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+ text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ if self.learned_classifier_free_sampling_embeddings.learnable:
+ uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+ uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+ else:
+ uncond_tokens = [""] * batch_size
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # See comment for normalizing text embeddings
+ uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ num_inference_steps: int = 100,
+ guidance_scale: float = 5.0,
+ truncation_rate: float = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
+ Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
+ most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
+ `truncation_rate` are set to zero.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor` of shape (batch), *optional*):
+ Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
+ Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
+ be generated of completely masked latent pixels.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get the initial completely masked latents unless the user supplied it
+
+ latents_shape = (batch_size, self.transformer.num_latent_pixels)
+ if latents is None:
+ mask_class = self.transformer.num_vector_embeds - 1
+ latents = torch.full(latents_shape, mask_class).to(self.device)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
+ raise ValueError(
+ "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
+ f" {self.transformer.num_vector_embeds - 1} (inclusive)."
+ )
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ sample = latents
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the sample if we are doing classifier free guidance
+ latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
+ # predict the un-noised image
+ # model_output == `log_p_x_0`
+ model_output = self.transformer(
+ latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+ ).sample
+
+ if do_classifier_free_guidance:
+ model_output_uncond, model_output_text = model_output.chunk(2)
+ model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+ model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
+
+ model_output = self.truncate(model_output, truncation_rate)
+
+ # remove `log(0)`'s (`-inf`s)
+ model_output = model_output.clamp(-70)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, sample)
+
+ embedding_channels = self.vqvae.config.vq_embed_dim
+ embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
+ embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
+ image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+ def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
+ """
+ Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The
+ lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero.
+ """
+ sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
+ sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
+ keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
+
+ # Ensure that at least the largest probability is not zeroed out
+ all_true = torch.full_like(keep_mask[:, 0:1, :], True)
+ keep_mask = torch.cat((all_true, keep_mask), dim=1)
+ keep_mask = keep_mask[:, :-1, :]
+
+ keep_mask = keep_mask.gather(1, indices.argsort(1))
+
+ rv = log_p_x_0.clone()
+
+ rv[~keep_mask] = -torch.inf # -inf = log(0)
+
+ return rv
diff --git a/diffusers/schedulers/README.md b/diffusers/schedulers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9494e357fd43465d5a1aa4da1bf784e1fcc40039
--- /dev/null
+++ b/diffusers/schedulers/README.md
@@ -0,0 +1,3 @@
+# Schedulers
+
+For more information on the schedulers, please refer to the [docs](https://huggingface.co/docs/diffusers/api/schedulers).
\ No newline at end of file
diff --git a/diffusers/schedulers/__init__.py b/diffusers/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3069f1eed2800c55b007ede0bd2770c1a115653
--- /dev/null
+++ b/diffusers/schedulers/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..utils import is_flax_available, is_scipy_available, is_torch_available
+
+
+if is_torch_available():
+ from .scheduling_ddim import DDIMScheduler
+ from .scheduling_ddpm import DDPMScheduler
+ from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
+ from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+ from .scheduling_euler_discrete import EulerDiscreteScheduler
+ from .scheduling_heun_discrete import HeunDiscreteScheduler
+ from .scheduling_ipndm import IPNDMScheduler
+ from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
+ from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
+ from .scheduling_karras_ve import KarrasVeScheduler
+ from .scheduling_pndm import PNDMScheduler
+ from .scheduling_repaint import RePaintScheduler
+ from .scheduling_sde_ve import ScoreSdeVeScheduler
+ from .scheduling_sde_vp import ScoreSdeVpScheduler
+ from .scheduling_utils import SchedulerMixin
+ from .scheduling_vq_diffusion import VQDiffusionScheduler
+else:
+ from ..utils.dummy_pt_objects import * # noqa F403
+
+if is_flax_available():
+ from .scheduling_ddim_flax import FlaxDDIMScheduler
+ from .scheduling_ddpm_flax import FlaxDDPMScheduler
+ from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
+ from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
+ from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
+ from .scheduling_pndm_flax import FlaxPNDMScheduler
+ from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+ from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+else:
+ from ..utils.dummy_flax_objects import * # noqa F403
+
+
+if is_scipy_available() and is_torch_available():
+ from .scheduling_lms_discrete import LMSDiscreteScheduler
+else:
+ from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
diff --git a/diffusers/schedulers/scheduling_ddim.py b/diffusers/schedulers/scheduling_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd38bd63c27d9665d6e4d97ea5068eb40ec44c5f
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim.py
@@ -0,0 +1,380 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas)
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " DDIMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.timesteps += self.config.steps_offset
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
+ generator: random number generator.
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+ can directly provide the noise for the variance itself. This is useful for methods such as
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # predict V
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+ device = model_output.device
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+ variance_noise = variance_noise.to(device)
+ else:
+ variance_noise = torch.randn(
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+ )
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddim_flax.py b/diffusers/schedulers/scheduling_ddim_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..157321d4681639c865e77745f9513b9a9a43b466
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim_flax.py
@@ -0,0 +1,326 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=jnp.float32)
+
+
+@flax.struct.dataclass
+class DDIMSchedulerState:
+ # setable values
+ timesteps: jnp.ndarray
+ alphas_cumprod: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ @classmethod
+ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
+ return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
+
+
+@dataclass
+class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
+ state: DDIMSchedulerState
+
+
+class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
+ `v-prediction` is not supported for this scheduler.
+
+ """
+
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " FlaxDDIMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if beta_schedule == "linear":
+ self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+
+ # HACK for now - clean up later (PVP)
+ self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ def scale_model_input(
+ self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def create_state(self):
+ return DDIMSchedulerState.create(
+ num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
+ )
+
+ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(
+ self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> DDIMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DDIMSchedulerState`):
+ the `FlaxDDIMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ offset = self.config.steps_offset
+
+ step_ratio = self.config.num_train_timesteps // num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+ timesteps = timesteps + offset
+
+ return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+ def step(
+ self,
+ state: DDIMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class
+
+ Returns:
+ [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
+ eta = 0.0
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+
+ alphas_cumprod = state.alphas_cumprod
+
+ # 2. compute alphas, betas
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # predict V
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+ std_dev_t = eta * variance ** (0.5)
+
+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def add_noise(
+ self,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddpm.py b/diffusers/schedulers/scheduling_ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..369db8b29e7d2e9abb9707bbf877ba7707f664eb
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm.py
@@ -0,0 +1,373 @@
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class DDPMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ prediction_type: str = "epsilon",
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.variance_type = variance_type
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = torch.clamp(variance, min=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = torch.log(torch.clamp(variance, min=1e-20))
+ variance = torch.exp(0.5 * variance)
+ elif variance_type == "fixed_large":
+ variance = self.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = torch.log(self.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = self.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ new_config = dict(self.config)
+ new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
+ self._internal_dict = FrozenDict(new_config)
+
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for the DDPMScheduler."
+ )
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ device = model_output.device
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+ variance_noise = variance_noise.to(device)
+ else:
+ variance_noise = torch.randn(
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+ )
+ if self.variance_type == "fixed_small_log":
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+ else:
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddpm_flax.py b/diffusers/schedulers/scheduling_ddpm_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..e716ea0abaad045b86d902cb41362027092d7349
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -0,0 +1,318 @@
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import deprecate
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=jnp.float32)
+
+
+@flax.struct.dataclass
+class DDPMSchedulerState:
+ # setable values
+ timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+
+ @classmethod
+ def create(cls, num_train_timesteps: int):
+ return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+
+
+@dataclass
+class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
+ state: DDPMSchedulerState
+
+
+class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
+ `v-prediction` is not supported for this scheduler.
+ """
+
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ prediction_type: str = "epsilon",
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if trained_betas is not None:
+ self.betas = jnp.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+ self.one = jnp.array(1.0)
+
+ def create_state(self):
+ return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
+
+ def set_timesteps(
+ self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> DDPMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DDIMSchedulerState`):
+ the `FlaxDDPMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ timesteps = jnp.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps
+ )[::-1]
+ return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probably added for training stability
+ if variance_type == "fixed_small":
+ variance = jnp.clip(variance, a_min=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = jnp.log(jnp.clip(variance, a_min=1e-20))
+ elif variance_type == "fixed_large":
+ variance = self.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = jnp.log(self.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = self.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ state: DDPMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ key: random.KeyArray,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ key (`random.KeyArray`): a PRNG key.
+ return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
+
+ Returns:
+ [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ new_config = dict(self.config)
+ new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
+ self._internal_dict = FrozenDict(new_config)
+
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
+ " for the FlaxDDPMScheduler."
+ )
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ key = random.split(key, num=1)
+ noise = random.normal(key=key, shape=model_output.shape)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample, state)
+
+ return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+
+ def add_noise(
+ self,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/diffusers/schedulers/scheduling_dpmsolver_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d0838026913b751128ad23925cef1e978fa906
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -0,0 +1,533 @@
+# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
+ samples, and it can generate quite good samples even in only 10 steps.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
+ recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
+ algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
+ https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
+ sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend to use the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " DPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+ raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+ if solver_type not in ["midpoint", "heun"]:
+ raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = (
+ np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .copy()
+ .astype(np.int64)
+ )
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ def convert_model_output(
+ self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction model and data prediction model.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the converted model output.
+ """
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type == "dpmsolver++":
+ if self.config.prediction_type == "epsilon":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+ orig_dtype = x0_pred.dtype
+ if orig_dtype not in [torch.float, torch.double]:
+ x0_pred = x0_pred.float()
+ dynamic_max_val = torch.quantile(
+ torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
+ )
+ dynamic_max_val = torch.maximum(
+ dynamic_max_val,
+ self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
+ )[(...,) + (None,) * (x0_pred.ndim - 1)]
+ x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
+ x0_pred = x0_pred.type(orig_dtype)
+ return x0_pred
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type == "dpmsolver":
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction` for the DPMSolverMultistepScheduler."
+ )
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the second-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: List[torch.FloatTensor],
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ One step for the third-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[torch.FloatTensor]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ self.lambda_t[t],
+ self.lambda_t[s0],
+ self.lambda_t[s1],
+ self.lambda_t[s2],
+ )
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the multistep DPM-Solver.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero()
+ if len(step_index) == 0:
+ step_index = len(self.timesteps) - 1
+ else:
+ step_index = step_index.item()
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+ lower_order_final = (
+ (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+ lower_order_second = (
+ (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ timestep_list = [self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_second_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+ else:
+ timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+ prev_sample = self.multistep_dpm_solver_third_order_update(
+ self.model_outputs, timestep_list, prev_timestep, sample
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44070d1d2aa1b5964884f17f1cbf335b9433f8e
--- /dev/null
+++ b/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
@@ -0,0 +1,625 @@
+# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=jnp.float32)
+
+
+@flax.struct.dataclass
+class DPMSolverMultistepSchedulerState:
+ # setable values
+ num_inference_steps: Optional[int] = None
+ timesteps: Optional[jnp.ndarray] = None
+
+ # running values
+ model_outputs: Optional[jnp.ndarray] = None
+ lower_order_nums: Optional[int] = None
+ step_index: Optional[int] = None
+ prev_timestep: Optional[int] = None
+ cur_sample: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(cls, num_train_timesteps: int):
+ return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+
+
+@dataclass
+class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
+ state: DPMSolverMultistepSchedulerState
+
+
+class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
+ samples, and it can generate quite good samples even in only 10 steps.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
+ recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
+
+ We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+ diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+ thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ solver_order (`int`, default `2`):
+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
+ sampling, and `solver_order=3` for unconditional sampling.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
+ or `v-prediction`.
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
+ use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
+ models (such as stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487).
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
+ `algorithm_type="dpmsolver++`.
+ algorithm_type (`str`, default `dpmsolver++`):
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
+ algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
+ https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
+ sampling (e.g. stable-diffusion).
+ solver_type (`str`, default `midpoint`):
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
+ the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
+ slightly better, so we recommend to use the `midpoint` type.
+ lower_order_final (`bool`, default `True`):
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
+
+ """
+
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ _deprecated_kwargs = ["predict_epsilon"]
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ solver_order: int = 2,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ algorithm_type: str = "dpmsolver++",
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ **kwargs,
+ ):
+ message = (
+ "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+ " FlaxDPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`."
+ )
+ predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
+ if predict_epsilon is not None:
+ self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+ if trained_betas is not None:
+ self.betas = jnp.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+ # Currently we only support VP-type noise schedule
+ self.alpha_t = jnp.sqrt(self.alphas_cumprod)
+ self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod)
+ self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # settings for DPM-Solver
+ if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+ raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+ if solver_type not in ["midpoint", "heun"]:
+ raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+
+ def create_state(self):
+ return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
+
+ def set_timesteps(
+ self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
+ ) -> DPMSolverMultistepSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ shape (`Tuple`):
+ the shape of the samples to be generated.
+ """
+ timesteps = (
+ jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+ .round()[::-1][:-1]
+ .astype(jnp.int32)
+ )
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ model_outputs=jnp.zeros((self.config.solver_order,) + shape),
+ lower_order_nums=0,
+ step_index=0,
+ prev_timestep=-1,
+ cur_sample=jnp.zeros(shape),
+ )
+
+ def convert_model_output(
+ self,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
+
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
+ corresponding type to match the algorithm.
+
+ Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
+ DPM-Solver++ for both noise prediction model and data prediction model.
+
+ Args:
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the converted model output.
+ """
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
+ if self.config.algorithm_type == "dpmsolver++":
+ if self.config.prediction_type == "epsilon":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
+ elif self.config.prediction_type == "sample":
+ x0_pred = model_output
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ x0_pred = alpha_t * sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+ dynamic_max_val = jnp.percentile(
+ jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
+ )
+ dynamic_max_val = jnp.maximum(
+ dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
+ )
+ x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
+ return x0_pred
+ # DPM-Solver needs to solve an integral of the noise prediction model.
+ elif self.config.algorithm_type == "dpmsolver":
+ if self.config.prediction_type == "epsilon":
+ return model_output
+ elif self.config.prediction_type == "sample":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = (sample - alpha_t * model_output) / sigma_t
+ return epsilon
+ elif self.config.prediction_type == "v_prediction":
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+ epsilon = alpha_t * model_output + sigma_t * sample
+ return epsilon
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
+ )
+
+ def dpm_solver_first_order_update(
+ self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
+ ) -> jnp.ndarray:
+ """
+ One step for the first-order DPM-Solver (equivalent to DDIM).
+
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+
+ Args:
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0 = prev_timestep, timestep
+ m0 = model_output
+ lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0]
+ alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0]
+ h = lambda_t - lambda_s
+ if self.config.algorithm_type == "dpmsolver++":
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
+ elif self.config.algorithm_type == "dpmsolver":
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: jnp.ndarray,
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ One step for the second-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[jnp.ndarray]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_order_update(
+ self,
+ model_output_list: jnp.ndarray,
+ timestep_list: List[int],
+ prev_timestep: int,
+ sample: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """
+ One step for the third-order multistep DPM-Solver.
+
+ Args:
+ model_output_list (`List[jnp.ndarray]`):
+ direct outputs from learned diffusion model at current and latter timesteps.
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+
+ Returns:
+ `jnp.ndarray`: the sample tensor at the previous timestep.
+ """
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
+ self.lambda_t[t],
+ self.lambda_t[s0],
+ self.lambda_t[s1],
+ self.lambda_t[s2],
+ )
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
+ r0, r1 = h_0 / h, h_1 / h
+ D0 = m0
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
+ if self.config.algorithm_type == "dpmsolver++":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (sigma_t / sigma_s0) * sample
+ - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
+ - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
+ )
+ elif self.config.algorithm_type == "dpmsolver":
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ x_t = (
+ (alpha_t / alpha_s0) * sample
+ - (sigma_t * (jnp.exp(h) - 1.0)) * D0
+ - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
+ - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
+ )
+ return x_t
+
+ def step(
+ self,
+ state: DPMSolverMultistepSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
+ from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class
+
+ Returns:
+ [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ prev_timestep = jax.lax.cond(
+ state.step_index == len(state.timesteps) - 1,
+ lambda _: 0,
+ lambda _: state.timesteps[state.step_index + 1],
+ (),
+ )
+
+ model_output = self.convert_model_output(model_output, timestep, sample)
+
+ model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
+ model_outputs_new = model_outputs_new.at[-1].set(model_output)
+ state = state.replace(
+ model_outputs=model_outputs_new,
+ prev_timestep=prev_timestep,
+ cur_sample=sample,
+ )
+
+ def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ return self.dpm_solver_first_order_update(
+ state.model_outputs[-1],
+ state.timesteps[state.step_index],
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]])
+ return self.multistep_dpm_solver_second_order_update(
+ state.model_outputs,
+ timestep_list,
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
+ timestep_list = jnp.array(
+ [
+ state.timesteps[state.step_index - 2],
+ state.timesteps[state.step_index - 1],
+ state.timesteps[state.step_index],
+ ]
+ )
+ return self.multistep_dpm_solver_third_order_update(
+ state.model_outputs,
+ timestep_list,
+ state.prev_timestep,
+ state.cur_sample,
+ )
+
+ if self.config.solver_order == 2:
+ return step_2(state)
+ elif self.config.lower_order_final and len(state.timesteps) < 15:
+ return jax.lax.cond(
+ state.lower_order_nums < 2,
+ step_2,
+ lambda state: jax.lax.cond(
+ state.step_index == len(state.timesteps) - 2,
+ step_2,
+ step_3,
+ state,
+ ),
+ state,
+ )
+ else:
+ return jax.lax.cond(
+ state.lower_order_nums < 2,
+ step_2,
+ step_3,
+ state,
+ )
+
+ if self.config.solver_order == 1:
+ prev_sample = step_1(state)
+ elif self.config.lower_order_final and len(state.timesteps) < 15:
+ prev_sample = jax.lax.cond(
+ state.lower_order_nums < 1,
+ step_1,
+ lambda state: jax.lax.cond(
+ state.step_index == len(state.timesteps) - 1,
+ step_1,
+ step_23,
+ state,
+ ),
+ state,
+ )
+ else:
+ prev_sample = jax.lax.cond(
+ state.lower_order_nums < 1,
+ step_1,
+ step_23,
+ state,
+ )
+
+ state = state.replace(
+ lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
+ step_index=(state.step_index + 1),
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def scale_model_input(
+ self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ state (`DPMSolverMultistepSchedulerState`):
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5905a3f83641979de0679331bfc51bb2aa7cd50
--- /dev/null
+++ b/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -0,0 +1,279 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
+class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.is_scale_input_called = False
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator (`torch.Generator`, optional): Random number generator.
+ return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
+ a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep.",
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ sigma_from = self.sigmas[step_index]
+ sigma_to = self.sigmas[step_index + 1]
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+
+ dt = sigma_down - sigma
+
+ prev_sample = sample + derivative * dt
+
+ device = model_output.device
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
+ device
+ )
+ else:
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
+ device
+ )
+
+ prev_sample = prev_sample + noise * sigma_up
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerAncestralDiscreteSchedulerOutput(
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
+ )
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ schedule_timesteps = self.timesteps
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_euler_discrete.py b/diffusers/schedulers/scheduling_euler_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cb4a1eaa565acbf51970911248e1bf0d604c979
--- /dev/null
+++ b/diffusers/schedulers/scheduling_euler_discrete.py
@@ -0,0 +1,287 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+class EulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
+ k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.is_scale_input_called = False
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ s_churn (`float`)
+ s_tmin (`float`)
+ s_tmax (`float`)
+ s_noise (`float`)
+ generator (`torch.Generator`, optional): Random number generator.
+ return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep.",
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+ device = model_output.device
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
+ device
+ )
+ else:
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
+ device
+ )
+
+ eps = noise * s_noise
+ sigma_hat = sigma * (gamma + 1)
+
+ if gamma > 0:
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma_hat * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma_hat
+
+ dt = self.sigmas[step_index + 1] - sigma_hat
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ schedule_timesteps = self.timesteps
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_heun_discrete.py b/diffusers/schedulers/scheduling_heun_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f40a24050b4604fd7b6af224bca4f65b075342d
--- /dev/null
+++ b/diffusers/schedulers/scheduling_heun_discrete.py
@@ -0,0 +1,264 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
+ k-diffusion implementation by Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
+ starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+ def index_for_timestep(self, timestep):
+ indices = (self.timesteps == timestep).nonzero()
+ if self.state_in_first_order:
+ pos = -1
+ else:
+ pos = 0
+ return indices[pos].item()
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ timesteps = torch.from_numpy(timesteps)
+ timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = timesteps.to(device, dtype=torch.float32)
+ else:
+ self.timesteps = timesteps.to(device=device)
+
+ # empty dt and derivative
+ self.prev_derivative = None
+ self.dt = None
+
+ @property
+ def state_in_first_order(self):
+ return self.dt is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Args:
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
+ (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # 2nd order / Heun's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_next = self.sigmas[step_index]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_next
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_next - sigma_hat
+
+ # store for 2nd order step
+ self.prev_derivative = derivative
+ self.dt = dt
+ self.sample = sample
+ else:
+ # 2. 2nd order / Heun's method
+ derivative = (sample - pred_original_sample) / sigma_next
+ derivative = (self.prev_derivative + derivative) / 2
+
+ # 3. take prev timestep & sample
+ dt = self.dt
+ sample = self.sample
+
+ # free dt and derivative
+ # Note, this puts the scheduler in "first order mode"
+ self.prev_derivative = None
+ self.dt = None
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t) for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ipndm.py b/diffusers/schedulers/scheduling_ipndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f22261d3ecd258485d21a77a49e105cb02af15f5
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ipndm.py
@@ -0,0 +1,161 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class IPNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion
+ [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+ ):
+ # set `betas`, `alphas`, `timesteps`
+ self.set_timesteps(num_train_timesteps)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.ets = []
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
+ steps = torch.cat([steps, torch.tensor([0.0])])
+
+ if self.config.trained_betas is not None:
+ self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+ else:
+ self.betas = torch.sin(steps * math.pi / 2) ** 2
+
+ self.alphas = (1.0 - self.betas**2) ** 0.5
+
+ timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
+ self.timesteps = timesteps.to(device)
+
+ self.ets = []
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+ times to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep_index = (self.timesteps == timestep).nonzero().item()
+ prev_timestep_index = timestep_index + 1
+
+ ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
+ self.ets.append(ets)
+
+ if len(self.ets) == 1:
+ ets = self.ets[-1]
+ elif len(self.ets) == 2:
+ ets = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
+ alpha = self.alphas[timestep_index]
+ sigma = self.betas[timestep_index]
+
+ next_alpha = self.alphas[prev_timestep_index]
+ next_sigma = self.betas[prev_timestep_index]
+
+ pred = (sample - sigma * ets) / max(alpha, 1e-8)
+ prev_sample = next_alpha * pred + ets * next_sigma
+
+ return prev_sample
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d2175f027a6e83d5b77824ea1edd309ae76128
--- /dev/null
+++ b/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -0,0 +1,324 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+ https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+ Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
+ starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+ def index_for_timestep(self, timestep):
+ indices = (self.timesteps == timestep).nonzero()
+ if self.state_in_first_order:
+ pos = -1
+ else:
+ pos = 0
+ return indices[pos].item()
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ else:
+ sigma = self.sigmas_interpol[step_index - 1]
+
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+
+ # compute up and down sigmas
+ sigmas_next = sigmas.roll(-1)
+ sigmas_next[-1] = 0.0
+ sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
+ sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
+ sigmas_down[-1] = 0.0
+
+ # compute interpolated sigmas
+ sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp()
+ sigmas_interpol[-2:] = 0.0
+
+ # set sigmas
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
+ self.sigmas_interpol = torch.cat(
+ [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
+ )
+ self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
+ self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ timesteps = torch.from_numpy(timesteps).to(device)
+ timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
+ interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
+ timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = timesteps.to(device, dtype=torch.float32)
+ else:
+ self.timesteps = timesteps
+
+ self.sample = None
+
+ def sigma_to_t(self, sigma):
+ # get log sigma
+ log_sigma = sigma.log()
+
+ # get distribution
+ dists = log_sigma - self.log_sigmas[:, None]
+
+ # get sigmas range
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = self.log_sigmas[low_idx]
+ high = self.log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = w.clamp(0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.view(sigma.shape)
+ return t
+
+ @property
+ def state_in_first_order(self):
+ return self.sample is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Args:
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
+ (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_interpol = self.sigmas_interpol[step_index]
+ sigma_up = self.sigmas_up[step_index]
+ sigma_down = self.sigmas_down[step_index - 1]
+ else:
+ # 2nd order / KPDM2's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_interpol = self.sigmas_interpol[step_index - 1]
+ sigma_up = self.sigmas_up[step_index - 1]
+ sigma_down = self.sigmas_down[step_index - 1]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ device = model_output.device
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
+ device
+ )
+ else:
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
+ device
+ )
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_interpol - sigma_hat
+
+ # store for 2nd order step
+ self.sample = sample
+ self.dt = dt
+ prev_sample = sample + derivative * dt
+ else:
+ # DPM-Solver-2
+ # 2. Convert to an ODE derivative for 2nd order
+ derivative = (sample - pred_original_sample) / sigma_interpol
+ # 3. delta timestep
+ dt = sigma_down - sigma_hat
+
+ sample = self.sample
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
+ prev_sample = prev_sample + noise * sigma_up
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t) for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aee346c574c14d52c6b67a2c2275cedc3f6a2cc
--- /dev/null
+++ b/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -0,0 +1,299 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
+ https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
+
+ Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
+ starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085, # sensible defaults
+ beta_end: float = 0.012,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # set all values
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+
+ def index_for_timestep(self, timestep):
+ indices = (self.timesteps == timestep).nonzero()
+ if self.state_in_first_order:
+ pos = -1
+ else:
+ pos = 0
+ return indices[pos].item()
+
+ def scale_model_input(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ ) -> torch.FloatTensor:
+ """
+ Args:
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ else:
+ sigma = self.sigmas_interpol[step_index]
+
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ num_train_timesteps: Optional[int] = None,
+ ):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
+
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
+
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ sigmas = torch.from_numpy(sigmas).to(device=device)
+
+ # interpolate sigmas
+ sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()
+
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
+ self.sigmas_interpol = torch.cat(
+ [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
+ )
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ timesteps = torch.from_numpy(timesteps).to(device)
+
+ # interpolate timesteps
+ timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
+ interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
+ timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
+
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = timesteps.to(torch.float32)
+ else:
+ self.timesteps = timesteps
+
+ self.sample = None
+
+ def sigma_to_t(self, sigma):
+ # get log sigma
+ log_sigma = sigma.log()
+
+ # get distribution
+ dists = log_sigma - self.log_sigmas[:, None]
+
+ # get sigmas range
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = self.log_sigmas[low_idx]
+ high = self.log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = w.clamp(0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.view(sigma.shape)
+ return t
+
+ @property
+ def state_in_first_order(self):
+ return self.sample is None
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: Union[float, torch.FloatTensor],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Args:
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
+ (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ step_index = self.index_for_timestep(timestep)
+
+ if self.state_in_first_order:
+ sigma = self.sigmas[step_index]
+ sigma_interpol = self.sigmas_interpol[step_index + 1]
+ sigma_next = self.sigmas[step_index + 1]
+ else:
+ # 2nd order / KDPM2's method
+ sigma = self.sigmas[step_index - 1]
+ sigma_interpol = self.sigmas_interpol[step_index]
+ sigma_next = self.sigmas[step_index]
+
+ # currently only gamma=0 is supported. This usually works best anyways.
+ # We can support gamma in the future but then need to scale the timestep before
+ # passing it to the model which requires a change in API
+ gamma = 0
+ sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = sample - sigma_input * model_output
+ elif self.config.prediction_type == "v_prediction":
+ sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
+ pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
+ sample / (sigma_input**2 + 1)
+ )
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ if self.state_in_first_order:
+ # 2. Convert to an ODE derivative for 1st order
+ derivative = (sample - pred_original_sample) / sigma_hat
+ # 3. delta timestep
+ dt = sigma_interpol - sigma_hat
+
+ # store for 2nd order step
+ self.sample = sample
+ else:
+ # DPM-Solver-2
+ # 2. Convert to an ODE derivative for 2nd order
+ derivative = (sample - pred_original_sample) / sigma_interpol
+
+ # 3. delta timestep
+ dt = sigma_next - sigma_hat
+
+ sample = self.sample
+ self.sample = None
+
+ prev_sample = sample + derivative * dt
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ self.timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [self.index_for_timestep(t) for t in timesteps]
+
+ sigma = self.sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_karras_ve.py b/diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..41a73b3ac36e8985a3e1cf781afc06b0e6f6ed48
--- /dev/null
+++ b/diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,232 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+
+ """
+
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ # setable values
+ self.num_inference_steps: int = None
+ self.timesteps: np.IntTensor = None
+ self.schedule: torch.FloatTensor = None # sigma(t_i)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ schedule = [
+ (
+ self.config.sigma_max**2
+ * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+ )
+ for i in self.timesteps
+ ]
+ self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
+
+ def add_noise_to_input(
+ self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[torch.FloatTensor, float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+ TODO Args:
+ """
+ if self.config.s_min <= sigma <= self.config.s_max:
+ gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+ KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
+ Returns:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def step_correct(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ sample_prev: torch.FloatTensor,
+ derivative: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ sample_prev (`torch.FloatTensor`): TODO
+ derivative (`torch.FloatTensor`): TODO
+ return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+ Returns:
+ prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/diffusers/schedulers/scheduling_karras_ve_flax.py b/diffusers/schedulers/scheduling_karras_ve_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e612c3cc84021f2959b16984442ba8de184fa7
--- /dev/null
+++ b/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -0,0 +1,237 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin
+
+
+@flax.struct.dataclass
+class KarrasVeSchedulerState:
+ # setable values
+ num_inference_steps: Optional[int] = None
+ timesteps: Optional[jnp.ndarray] = None
+ schedule: Optional[jnp.ndarray] = None # sigma(t_i)
+
+ @classmethod
+ def create(cls):
+ return cls()
+
+
+@dataclass
+class FlaxKarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ """
+
+ prev_sample: jnp.ndarray
+ derivative: jnp.ndarray
+ state: KarrasVeSchedulerState
+
+
+class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+ """
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ ):
+ pass
+
+ def create_state(self):
+ return KarrasVeSchedulerState.create()
+
+ def set_timesteps(
+ self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> KarrasVeSchedulerState:
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`KarrasVeSchedulerState`):
+ the `FlaxKarrasVeScheduler` state data class.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
+ schedule = [
+ (
+ self.config.sigma_max**2
+ * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+ )
+ for i in timesteps
+ ]
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ schedule=jnp.array(schedule, dtype=jnp.float32),
+ timesteps=timesteps,
+ )
+
+ def add_noise_to_input(
+ self,
+ state: KarrasVeSchedulerState,
+ sample: jnp.ndarray,
+ sigma: float,
+ key: random.KeyArray,
+ ) -> Tuple[jnp.ndarray, float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+ TODO Args:
+ """
+ if self.config.s_min <= sigma <= self.config.s_max:
+ gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ key = random.split(key, num=1)
+ eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ state: KarrasVeSchedulerState,
+ model_output: jnp.ndarray,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxKarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
+ chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative, state)
+
+ return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
+
+ def step_correct(
+ self,
+ state: KarrasVeSchedulerState,
+ model_output: jnp.ndarray,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: jnp.ndarray,
+ sample_prev: jnp.ndarray,
+ derivative: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxKarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+ Args:
+ state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+ derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
+
+ Returns:
+ prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative, state)
+
+ return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..28bc9bd0c608650ba67982b4eb408bab9c215ba1
--- /dev/null
+++ b/diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,278 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
+class LMSDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()
+
+ # setable values
+ self.num_inference_steps = None
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
+ self.derivatives = []
+ self.is_scale_input_called = False
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ self.is_scale_input_called = True
+ return sample
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (TODO):
+ t (TODO):
+ current_order (TODO):
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ if str(device).startswith("mps"):
+ # mps does not support float64
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+ else:
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+ self.derivatives = []
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`float`): current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ order: coefficient for multi-step inference.
+ return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+
+ """
+ if not self.is_scale_input_called:
+ warnings.warn(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(step_index + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_lms_discrete_flax.py b/diffusers/schedulers/scheduling_lms_discrete_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..5da43be2ada3d471e4c146538c64d50c3700161f
--- /dev/null
+++ b/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -0,0 +1,242 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+@flax.struct.dataclass
+class LMSDiscreteSchedulerState:
+ # setable values
+ num_inference_steps: Optional[int] = None
+ timesteps: Optional[jnp.ndarray] = None
+ sigmas: Optional[jnp.ndarray] = None
+ derivatives: jnp.ndarray = jnp.array([])
+
+ @classmethod
+ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
+ return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)
+
+
+@dataclass
+class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
+ state: LMSDiscreteSchedulerState
+
+
+class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ """
+
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ ):
+ if trained_betas is not None:
+ self.betas = jnp.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+ def create_state(self):
+ self.state = LMSDiscreteSchedulerState.create(
+ num_train_timesteps=self.config.num_train_timesteps,
+ sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
+ )
+
+ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
+ """
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+ Args:
+ state (`LMSDiscreteSchedulerState`):
+ the `FlaxLMSDiscreteScheduler` state data class instance.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ timestep (`int`):
+ current discrete timestep in the diffusion chain.
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+ sigma = state.sigmas[step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+ return sample
+
+ def get_lms_coefficient(self, state, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (TODO):
+ t (TODO):
+ current_order (TODO):
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(
+ self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
+ ) -> LMSDiscreteSchedulerState:
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`LMSDiscreteSchedulerState`):
+ the `FlaxLMSDiscreteScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)
+
+ low_idx = jnp.floor(timesteps).astype(int)
+ high_idx = jnp.ceil(timesteps).astype(int)
+ frac = jnp.mod(timesteps, 1.0)
+ sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)
+
+ return state.replace(
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps.astype(int),
+ derivatives=jnp.array([]),
+ sigmas=sigmas,
+ )
+
+ def step(
+ self,
+ state: LMSDiscreteSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ order: coefficient for multi-step inference.
+ return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class
+
+ Returns:
+ [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ sigma = state.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ pred_original_sample = sample - sigma * model_output
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
+ if len(state.derivatives) > order:
+ state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def add_noise(
+ self,
+ state: LMSDiscreteSchedulerState,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sigma = state.sigmas[timesteps].flatten()
+ sigma = broadcast_to_shape_from_left(sigma, noise.shape)
+
+ noisy_samples = original_samples + noise * sigma
+
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_pndm.py b/diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a29f7d6d44cc628ac64bcb7225c5c494d4c70131
--- /dev/null
+++ b/diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,425 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+ set_alpha_to_one (`bool`, default `False`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
+ https://imagen.research.google/video/paper.pdf)
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+
+ """
+
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ prediction_type: str = "epsilon",
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+ self._timesteps += self.config.steps_offset
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.ets = []
+ self.counter = 0
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = timestep - diff_to_prev
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def step_plms(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+ times to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ if self.counter != 1:
+ self.ets = self.ets[-3:]
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation ( ->
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if self.config.prediction_type == "v_prediction":
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ elif self.config.prediction_type != "epsilon":
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+ )
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
+ # sqrt(α_(t−δ)) / sqrt(α_t))
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_pndm_flax.py b/diffusers/schedulers/scheduling_pndm_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..298e62de20d15febcd44b00f87046c431f4e2337
--- /dev/null
+++ b/diffusers/schedulers/scheduling_pndm_flax.py
@@ -0,0 +1,531 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax
+import jax.numpy as jnp
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return jnp.array(betas, dtype=jnp.float32)
+
+
+@flax.struct.dataclass
+class PNDMSchedulerState:
+ # setable values
+ _timesteps: jnp.ndarray
+ num_inference_steps: Optional[int] = None
+ prk_timesteps: Optional[jnp.ndarray] = None
+ plms_timesteps: Optional[jnp.ndarray] = None
+ timesteps: Optional[jnp.ndarray] = None
+
+ # running values
+ cur_model_output: Optional[jnp.ndarray] = None
+ counter: int = 0
+ cur_sample: Optional[jnp.ndarray] = None
+ ets: jnp.ndarray = jnp.array([])
+
+ @classmethod
+ def create(cls, num_train_timesteps: int):
+ return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+
+
+@dataclass
+class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
+ state: PNDMSchedulerState
+
+
+class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`jnp.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+ set_alpha_to_one (`bool`, default `False`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ """
+
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[jnp.ndarray] = None,
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ steps_offset: int = 0,
+ ):
+ if trained_betas is not None:
+ self.betas = jnp.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+ self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ def create_state(self):
+ return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
+
+ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`PNDMSchedulerState`):
+ the `FlaxPNDMScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ shape (`Tuple`):
+ the shape of the samples to be generated.
+ """
+ offset = self.config.steps_offset
+
+ step_ratio = self.config.num_train_timesteps // num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # rounding to avoid issues when num_inference_step is power of 3
+ _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
+
+ state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ state = state.replace(
+ prk_timesteps=jnp.array([]),
+ plms_timesteps=jnp.concatenate(
+ [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
+ )[::-1],
+ )
+ else:
+ prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
+ jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+
+ state = state.replace(
+ prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
+ plms_timesteps=state._timesteps[:-3][::-1],
+ )
+
+ return state.replace(
+ timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
+ counter=0,
+ # Reserve space for the state variables
+ cur_model_output=jnp.zeros(shape),
+ cur_sample=jnp.zeros(shape),
+ ets=jnp.zeros((4,) + shape),
+ )
+
+ def scale_model_input(
+ self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
+ ) -> jnp.ndarray:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ sample (`jnp.ndarray`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `jnp.ndarray`: scaled input sample
+ """
+ return sample
+
+ def step(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ return_dict: bool = True,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
+
+ Returns:
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.config.skip_prk_steps:
+ prev_sample, state = self.step_plms(
+ state=state, model_output=model_output, timestep=timestep, sample=sample
+ )
+ else:
+ prev_sample, state = jax.lax.switch(
+ jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+ (self.step_prk, self.step_plms),
+ # Args to either branch
+ state,
+ model_output,
+ timestep,
+ sample,
+ )
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
+
+ def step_prk(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
+
+ Returns:
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = jnp.where(
+ state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+ )
+ prev_timestep = timestep - diff_to_prev
+ timestep = state.prk_timesteps[state.counter // 4 * 4]
+
+ def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return (
+ state.replace(
+ cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+ ets=state.ets.at[ets_at].set(model_output),
+ cur_sample=sample,
+ ),
+ model_output,
+ )
+
+ def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+ def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+ def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+ model_output = state.cur_model_output + 1 / 6 * model_output
+ return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+ state, model_output = jax.lax.switch(
+ state.counter % 4,
+ (remainder_0, remainder_1, remainder_2, remainder_3),
+ # Args to either branch
+ state,
+ model_output,
+ state.counter // 4,
+ )
+
+ cur_sample = state.cur_sample
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ state = state.replace(counter=state.counter + 1)
+
+ return (prev_sample, state)
+
+ def step_plms(
+ self,
+ state: PNDMSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+ times to approximate the solution.
+
+ Args:
+ state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class
+
+ Returns:
+ [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(state.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+ prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
+
+ # Reference:
+ # if state.counter != 1:
+ # state.ets.append(model_output)
+ # else:
+ # prev_timestep = timestep
+ # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+ prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+ timestep = jnp.where(
+ state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+ )
+
+ # Reference:
+ # if len(state.ets) == 1 and state.counter == 0:
+ # model_output = model_output
+ # state.cur_sample = sample
+ # elif len(state.ets) == 1 and state.counter == 1:
+ # model_output = (model_output + state.ets[-1]) / 2
+ # sample = state.cur_sample
+ # state.cur_sample = None
+ # elif len(state.ets) == 2:
+ # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+ # elif len(state.ets) == 3:
+ # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+ # else:
+ # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+ def counter_0(state: PNDMSchedulerState):
+ ets = state.ets.at[0].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_sample=sample,
+ cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+ )
+
+ def counter_1(state: PNDMSchedulerState):
+ return state.replace(
+ cur_model_output=(model_output + state.ets[0]) / 2,
+ )
+
+ def counter_2(state: PNDMSchedulerState):
+ ets = state.ets.at[1].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_model_output=(3 * ets[1] - ets[0]) / 2,
+ cur_sample=sample,
+ )
+
+ def counter_3(state: PNDMSchedulerState):
+ ets = state.ets.at[2].set(model_output)
+ return state.replace(
+ ets=ets,
+ cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+ cur_sample=sample,
+ )
+
+ def counter_other(state: PNDMSchedulerState):
+ ets = state.ets.at[3].set(model_output)
+ next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+ ets = ets.at[0].set(ets[1])
+ ets = ets.at[1].set(ets[2])
+ ets = ets.at[2].set(ets[3])
+
+ return state.replace(
+ ets=ets,
+ cur_model_output=next_model_output,
+ cur_sample=sample,
+ )
+
+ counter = jnp.clip(state.counter, 0, 4)
+ state = jax.lax.switch(
+ counter,
+ [counter_0, counter_1, counter_2, counter_3, counter_other],
+ state,
+ )
+
+ sample = state.cur_sample
+ model_output = state.cur_model_output
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ state = state.replace(counter=state.counter + 1)
+
+ return (prev_sample, state)
+
+ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation ( ->
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
+ # sqrt(α_(t−δ)) / sqrt(α_t))
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: jnp.ndarray,
+ noise: jnp.ndarray,
+ timesteps: jnp.ndarray,
+ ) -> jnp.ndarray:
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_repaint.py b/diffusers/schedulers/scheduling_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b80181f438903a00fd20496d17a44a46c3cec46
--- /dev/null
+++ b/diffusers/schedulers/scheduling_repaint.py
@@ -0,0 +1,324 @@
+# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class RePaintSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from
+ the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: torch.FloatTensor
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class RePaintScheduler(SchedulerMixin, ConfigMixin):
+ """
+ RePaint is a schedule for DDPM inpainting inside a given mask.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ eta (`float`):
+ The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and
+ 1.0 is DDPM scheduler respectively.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ eta: float = 0.0,
+ trained_betas: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ self.one = torch.tensor(1.0)
+
+ self.final_alpha_cumprod = torch.tensor(1.0)
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ self.eta = eta
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ jump_length: int = 10,
+ jump_n_sample: int = 10,
+ device: Union[str, torch.device] = None,
+ ):
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+
+ timesteps = []
+
+ jumps = {}
+ for j in range(0, num_inference_steps - jump_length, jump_length):
+ jumps[j] = jump_n_sample - 1
+
+ t = num_inference_steps
+ while t >= 1:
+ t = t - 1
+ timesteps.append(t)
+
+ if jumps.get(t, 0) > 0:
+ jumps[t] = jumps[t] - 1
+ for _ in range(jump_length):
+ t = t + 1
+ timesteps.append(t)
+
+ timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_variance(self, t):
+ prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps
+
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from
+ # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
+ # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
+ # variance to pred_sample
+ # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
+ # without eta.
+ # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ original_image: torch.FloatTensor,
+ mask: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[RePaintSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned
+ diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ original_image (`torch.FloatTensor`):
+ the original image to inpaint on.
+ mask (`torch.FloatTensor`):
+ the mask where 0.0 values define which part of the original image to inpaint (change).
+ generator (`torch.Generator`, *optional*): random number generator.
+ return_dict (`bool`): option for returning tuple rather than
+ DDPMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
+ # substitute formula (7) in the algorithm coming from DDPM paper
+ # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper.
+ # DDIM schedule gives the same results as DDPM with eta = 1.0
+ # Noise is being reused in 7. and 8., but no impact on quality has
+ # been observed.
+
+ # 5. Add noise
+ noise = torch.randn(
+ model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+ )
+ std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
+
+ variance = 0
+ if t > 0 and self.eta > 0:
+ variance = std_dev_t * noise
+
+ # 6. compute "direction pointing to x_t" of formula (12)
+ # from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
+
+ # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
+
+ # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
+ prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
+
+ # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
+ pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
+
+ if not return_dict:
+ return (
+ pred_prev_sample,
+ pred_original_sample,
+ )
+
+ return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ def undo_step(self, sample, timestep, generator=None):
+ n = self.config.num_train_timesteps // self.num_inference_steps
+
+ for i in range(n):
+ beta = self.betas[timestep + i]
+ noise = torch.randn(sample.shape, generator=generator, device=sample.device)
+
+ # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
+ sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
+
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/schedulers/scheduling_sde_ve.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d3d4a5858785731c0d60bdc5118a092d26f335
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_ve.py
@@ -0,0 +1,266 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@dataclass
+class SdeVeOutput(BaseOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ prev_sample: torch.FloatTensor
+ prev_sample_mean: torch.FloatTensor
+
+
+class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+ sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ # setable values
+ self.timesteps = None
+
+ self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+ ):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
+
+ def set_sigmas(
+ self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
+ ):
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if self.timesteps is None:
+ self.set_timesteps(num_inference_steps, sampling_eps)
+
+ self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
+ self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+
+ def get_adjacent_sigma(self, timesteps, t):
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+ )
+
+ def step_pred(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * torch.ones(
+ sample.shape[0], device=sample.device
+ ) # torch.repeat_interleave(timestep, sample.shape[0])
+ timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.discrete_sigmas.device)
+
+ sigma = self.discrete_sigmas[timesteps].to(sample.device)
+ adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
+ drift = torch.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ diffusion = diffusion.flatten()
+ while len(diffusion.shape) < len(sample.shape):
+ diffusion = diffusion.unsqueeze(-1)
+ drift = drift - diffusion**2 * model_output
+
+ # equation 6: sample noise for the diffusion term of
+ noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean)
+
+ return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
+
+ def step_correct(
+ self,
+ model_output: torch.FloatTensor,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+ # self.repeat_scalar(step_size, sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ step_size = step_size.flatten()
+ while len(step_size.shape) < len(sample.shape):
+ step_size = step_size.unsqueeze(-1)
+ prev_sample_mean = sample + step_size * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_ve_flax.py b/diffusers/schedulers/scheduling_sde_ve_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f762bc90c471d6bbc7f33e5854d014b1e25667
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -0,0 +1,276 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+
+
+@flax.struct.dataclass
+class ScoreSdeVeSchedulerState:
+ # setable values
+ timesteps: Optional[jnp.ndarray] = None
+ discrete_sigmas: Optional[jnp.ndarray] = None
+ sigmas: Optional[jnp.ndarray] = None
+
+ @classmethod
+ def create(cls):
+ return cls()
+
+
+@dataclass
+class FlaxSdeVeOutput(FlaxSchedulerOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`):
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ state: ScoreSdeVeSchedulerState
+ prev_sample: jnp.ndarray
+ prev_sample_mean: Optional[jnp.ndarray] = None
+
+
+class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+ sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ """
+
+ @property
+ def has_state(self):
+ return True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ ):
+ pass
+
+ def create_state(self):
+ state = ScoreSdeVeSchedulerState.create()
+ return self.set_sigmas(
+ state,
+ self.config.num_train_timesteps,
+ self.config.sigma_min,
+ self.config.sigma_max,
+ self.config.sampling_eps,
+ )
+
+ def set_timesteps(
+ self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
+ ) -> ScoreSdeVeSchedulerState:
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+
+ timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
+ return state.replace(timesteps=timesteps)
+
+ def set_sigmas(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ num_inference_steps: int,
+ sigma_min: float = None,
+ sigma_max: float = None,
+ sampling_eps: float = None,
+ ) -> ScoreSdeVeSchedulerState:
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if state.timesteps is None:
+ state = self.set_timesteps(state, num_inference_steps, sampling_eps)
+
+ discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
+ sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])
+
+ return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)
+
+ def get_adjacent_sigma(self, state, timesteps, t):
+ return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])
+
+ def step_pred(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ model_output: jnp.ndarray,
+ timestep: int,
+ sample: jnp.ndarray,
+ key: random.KeyArray,
+ return_dict: bool = True,
+ ) -> Union[FlaxSdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
+
+ Returns:
+ [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.timesteps is None:
+ raise ValueError(
+ "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * jnp.ones(
+ sample.shape[0],
+ )
+ timesteps = (timestep * (len(state.timesteps) - 1)).long()
+
+ sigma = state.discrete_sigmas[timesteps]
+ adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
+ drift = jnp.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ diffusion = diffusion.flatten()
+ diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
+ drift = drift - diffusion**2 * model_output
+
+ # equation 6: sample noise for the diffusion term of
+ key = random.split(key, num=1)
+ noise = random.normal(key=key, shape=sample.shape)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean, state)
+
+ return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)
+
+ def step_correct(
+ self,
+ state: ScoreSdeVeSchedulerState,
+ model_output: jnp.ndarray,
+ sample: jnp.ndarray,
+ key: random.KeyArray,
+ return_dict: bool = True,
+ ) -> Union[FlaxSdeVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
+ sample (`jnp.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class
+
+ Returns:
+ [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if state.timesteps is None:
+ raise ValueError(
+ "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ key = random.split(key, num=1)
+ noise = random.normal(key=key, shape=sample.shape)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = jnp.linalg.norm(model_output)
+ noise_norm = jnp.linalg.norm(noise)
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * jnp.ones(sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ step_size = step_size.flatten()
+ step_size = broadcast_to_shape_from_left(step_size, sample.shape)
+ prev_sample_mean = sample + step_size * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
+
+ if not return_dict:
+ return (prev_sample, state)
+
+ return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_vp.py b/diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4fe40229cfdb915aaca768fc484366ef6d60e1
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,89 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import math
+from typing import Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance preserving stochastic differential equation (SDE) scheduler.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ UNDER CONSTRUCTION
+
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
+
+ def step_pred(self, score, x, t, generator=None):
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ std = std.flatten()
+ while len(std.shape) < len(score.shape):
+ std = std.unsqueeze(-1)
+ score = -score / std
+
+ # compute
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ beta_t = beta_t.flatten()
+ while len(beta_t.shape) < len(x.shape):
+ beta_t = beta_t.unsqueeze(-1)
+ drift = -0.5 * beta_t * x
+
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion**2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
+ x = x_mean + diffusion * math.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_utils.py b/diffusers/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ab674e38a40796dd1183ec0ef341159f8f62b4
--- /dev/null
+++ b/diffusers/schedulers/scheduling_utils.py
@@ -0,0 +1,154 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+
+ Class attributes:
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
+ by parent class).
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Dict[str, Any] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing the schedluer configurations saved using
+ [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~SchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
diff --git a/diffusers/schedulers/scheduling_utils_flax.py b/diffusers/schedulers/scheduling_utils_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dc28c25d9d6ec6c878b037cf4cd649e9b5a7b85
--- /dev/null
+++ b/diffusers/schedulers/scheduling_utils_flax.py
@@ -0,0 +1,169 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import jax.numpy as jnp
+
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
+
+
+@dataclass
+class FlaxSchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: jnp.ndarray
+
+
+class FlaxSchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+
+ Class attributes:
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
+ by parent class).
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Dict[str, Any] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
+ e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+ scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
+
+ if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
+ state = scheduler.create_state()
+
+ if return_unused_kwargs:
+ return scheduler, state, unused_kwargs
+
+ return scheduler, state
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~FlaxSchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+ assert len(shape) >= x.ndim
+ return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
diff --git a/diffusers/schedulers/scheduling_vq_diffusion.py b/diffusers/schedulers/scheduling_vq_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ba722a1852cbbac3bbd053effedbe97d370993
--- /dev/null
+++ b/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -0,0 +1,496 @@
+# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class VQDiffusionSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.LongTensor
+
+
+def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
+ """
+ Convert batch of vector of class indices into batch of log onehot vectors
+
+ Args:
+ x (`torch.LongTensor` of shape `(batch size, vector length)`):
+ Batch of class indices
+
+ num_classes (`int`):
+ number of classes to be used for the onehot vectors
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
+ Log onehot vectors
+ """
+ x_onehot = F.one_hot(x, num_classes)
+ x_onehot = x_onehot.permute(0, 2, 1)
+ log_x = torch.log(x_onehot.float().clamp(min=1e-30))
+ return log_x
+
+
+def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
+ """
+ Apply gumbel noise to `logits`
+ """
+ uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
+ gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
+ noised = gumbel_noise + logits
+ return noised
+
+
+def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
+ """
+ Cumulative and non-cumulative alpha schedules.
+
+ See section 4.1.
+ """
+ att = (
+ np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
+ + alpha_cum_start
+ )
+ att = np.concatenate(([1], att))
+ at = att[1:] / att[:-1]
+ att = np.concatenate((att[1:], [1]))
+ return at, att
+
+
+def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
+ """
+ Cumulative and non-cumulative gamma schedules.
+
+ See section 4.1.
+ """
+ ctt = (
+ np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
+ + gamma_cum_start
+ )
+ ctt = np.concatenate(([0], ctt))
+ one_minus_ctt = 1 - ctt
+ one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
+ ct = 1 - one_minus_ct
+ ctt = np.concatenate((ctt[1:], [0]))
+ return ct, ctt
+
+
+class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
+
+ The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
+ diffusion timestep.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2111.14822
+
+ Args:
+ num_vec_classes (`int`):
+ The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
+ latent pixel.
+
+ num_train_timesteps (`int`):
+ Number of diffusion steps used to train the model.
+
+ alpha_cum_start (`float`):
+ The starting cumulative alpha value.
+
+ alpha_cum_end (`float`):
+ The ending cumulative alpha value.
+
+ gamma_cum_start (`float`):
+ The starting cumulative gamma value.
+
+ gamma_cum_end (`float`):
+ The ending cumulative gamma value.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_vec_classes: int,
+ num_train_timesteps: int = 100,
+ alpha_cum_start: float = 0.99999,
+ alpha_cum_end: float = 0.000009,
+ gamma_cum_start: float = 0.000009,
+ gamma_cum_end: float = 0.99999,
+ ):
+ self.num_embed = num_vec_classes
+
+ # By convention, the index for the mask class is the last class index
+ self.mask_class = self.num_embed - 1
+
+ at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
+ ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
+
+ num_non_mask_classes = self.num_embed - 1
+ bt = (1 - at - ct) / num_non_mask_classes
+ btt = (1 - att - ctt) / num_non_mask_classes
+
+ at = torch.tensor(at.astype("float64"))
+ bt = torch.tensor(bt.astype("float64"))
+ ct = torch.tensor(ct.astype("float64"))
+ log_at = torch.log(at)
+ log_bt = torch.log(bt)
+ log_ct = torch.log(ct)
+
+ att = torch.tensor(att.astype("float64"))
+ btt = torch.tensor(btt.astype("float64"))
+ ctt = torch.tensor(ctt.astype("float64"))
+ log_cumprod_at = torch.log(att)
+ log_cumprod_bt = torch.log(btt)
+ log_cumprod_ct = torch.log(ctt)
+
+ self.log_at = log_at.float()
+ self.log_bt = log_bt.float()
+ self.log_ct = log_ct.float()
+ self.log_cumprod_at = log_cumprod_at.float()
+ self.log_cumprod_bt = log_cumprod_bt.float()
+ self.log_cumprod_ct = log_cumprod_ct.float()
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ device (`str` or `torch.device`):
+ device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ self.log_at = self.log_at.to(device)
+ self.log_bt = self.log_bt.to(device)
+ self.log_ct = self.log_ct.to(device)
+ self.log_cumprod_at = self.log_cumprod_at.to(device)
+ self.log_cumprod_bt = self.log_cumprod_bt.to(device)
+ self.log_cumprod_ct = self.log_cumprod_ct.to(device)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: torch.long,
+ sample: torch.LongTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[VQDiffusionSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
+ docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.
+
+ Args:
+ log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+ The log probabilities for the predicted classes of the initial latent pixels. Does not include a
+ prediction for the masked class as the initial unnoised image cannot be masked.
+
+ t (`torch.long`):
+ The timestep that determines which transition matrices are used.
+
+ x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `t`
+
+ generator: (`torch.Generator` or None):
+ RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.
+
+ return_dict (`bool`):
+ option for returning tuple rather than VQDiffusionSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
+ """
+ if timestep == 0:
+ log_p_x_t_min_1 = model_output
+ else:
+ log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
+
+ log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
+
+ x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
+
+ if not return_dict:
+ return (x_t_min_1,)
+
+ return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
+
+ def q_posterior(self, log_p_x_0, x_t, t):
+ """
+ Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
+
+ Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
+ forward probabilities.
+
+ Equation (11) stated in terms of forward probabilities via Equation (5):
+
+ Where:
+ - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0)
+
+ p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
+
+ Args:
+ log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+ The log probabilities for the predicted classes of the initial latent pixels. Does not include a
+ prediction for the masked class as the initial unnoised image cannot be masked.
+
+ x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `t`
+
+ t (torch.Long):
+ The timestep that determines which transition matrix is used.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
+ The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
+ """
+ log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
+
+ log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
+ t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
+ )
+
+ log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
+ t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
+ )
+
+ # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
+ # . . .
+ # . . .
+ # . . .
+ # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
+ q = log_p_x_0 - log_q_x_t_given_x_0
+
+ # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
+ # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
+ q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
+
+ # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
+ # . . .
+ # . . .
+ # . . .
+ # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
+ q = q - q_log_sum_exp
+
+ # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
+ # . . .
+ # . . .
+ # . . .
+ # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
+ # c_cumulative_{t-1} ... c_cumulative_{t-1}
+ q = self.apply_cumulative_transitions(q, t - 1)
+
+ # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
+ # . . .
+ # . . .
+ # . . .
+ # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
+ # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0
+ log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
+
+ # For each column, there are two possible cases.
+ #
+ # Where:
+ # - sum(p_n(x_0))) is summing over all classes for x_0
+ # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
+ # - C_j is the class transitioning to
+ #
+ # 1. x_t is masked i.e. x_t = c_k
+ #
+ # Simplifying the expression, the column vector is:
+ # .
+ # .
+ # .
+ # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
+ # .
+ # .
+ # .
+ # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
+ #
+ # From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
+ #
+ # For the other rows, we can state the equation as ...
+ #
+ # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})]
+ #
+ # This verifies the other rows.
+ #
+ # 2. x_t is not masked
+ #
+ # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
+ # .
+ # .
+ # .
+ # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
+ # .
+ # .
+ # .
+ # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
+ # .
+ # .
+ # .
+ # 0
+ #
+ # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
+ return log_p_x_t_min_1
+
+ def log_Q_t_transitioning_to_known_class(
+ self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
+ ):
+ """
+ Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
+ latent pixel in `x_t`.
+
+ See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
+ is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
+
+ Args:
+ t (torch.Long):
+ The timestep that determines which transition matrix is used.
+
+ x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
+ The classes of each latent pixel at time `t`.
+
+ log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
+ The log one-hot vectors of `x_t`
+
+ cumulative (`bool`):
+ If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
+ we use the cumulative transition matrix `0`->`t`.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
+ Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
+ transition matrix.
+
+ When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be
+ masked.
+
+ Where:
+ - `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
+ - C_0 is a class of a latent pixel embedding
+ - C_k is the class of the masked latent pixel
+
+ non-cumulative result (omitting logarithms):
+ ```
+ q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
+ . . .
+ . . .
+ . . .
+ q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
+ ```
+
+ cumulative result (omitting logarithms):
+ ```
+ q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
+ . . .
+ . . .
+ . . .
+ q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
+ ```
+ """
+ if cumulative:
+ a = self.log_cumprod_at[t]
+ b = self.log_cumprod_bt[t]
+ c = self.log_cumprod_ct[t]
+ else:
+ a = self.log_at[t]
+ b = self.log_bt[t]
+ c = self.log_ct[t]
+
+ if not cumulative:
+ # The values in the onehot vector can also be used as the logprobs for transitioning
+ # from masked latent pixels. If we are not calculating the cumulative transitions,
+ # we need to save these vectors to be re-appended to the final matrix so the values
+ # aren't overwritten.
+ #
+ # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector
+ # if x_t is not masked
+ #
+ # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector
+ # if x_t is masked
+ log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
+
+ # `index_to_log_onehot` will add onehot vectors for masked pixels,
+ # so the default one hot matrix has one too many rows. See the doc string
+ # for an explanation of the dimensionality of the returned matrix.
+ log_onehot_x_t = log_onehot_x_t[:, :-1, :]
+
+ # this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
+ #
+ # Don't worry about what values this sets in the columns that mark transitions
+ # to masked latent pixels. They are overwrote later with the `mask_class_mask`.
+ #
+ # Looking at the below logspace formula in non-logspace, each value will evaluate to either
+ # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
+ # or
+ # `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
+ #
+ # See equation 7 for more details.
+ log_Q_t = (log_onehot_x_t + a).logaddexp(b)
+
+ # The whole column of each masked pixel is `c`
+ mask_class_mask = x_t == self.mask_class
+ mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
+ log_Q_t[mask_class_mask] = c
+
+ if not cumulative:
+ log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
+
+ return log_Q_t
+
+ def apply_cumulative_transitions(self, q, t):
+ bsz = q.shape[0]
+ a = self.log_cumprod_at[t]
+ b = self.log_cumprod_bt[t]
+ c = self.log_cumprod_ct[t]
+
+ num_latent_pixels = q.shape[2]
+ c = c.expand(bsz, 1, num_latent_pixels)
+
+ q = (q + a).logaddexp(b)
+ q = torch.cat((q, c), dim=1)
+
+ return q
diff --git a/diffusers/training_utils.py b/diffusers/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1694161fc54c7fd097abf3bcbf44c498daad4b
--- /dev/null
+++ b/diffusers/training_utils.py
@@ -0,0 +1,125 @@
+import copy
+import os
+import random
+
+import numpy as np
+import torch
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # set seed first
+ set_seed(seed)
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+ """
+ Args:
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+ seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999,
+ device=None,
+ ):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = {}
+ ema_params = self.averaged_model.state_dict()
+
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_params[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_params[key] = ema_param
+
+ if not param.requires_grad:
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+ ema_param = ema_params[key]
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+ ema_state_dict[key] = ema_param
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ self.optimization_step += 1
diff --git a/diffusers/utils/__init__.py b/diffusers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2e2c9abbc61f01e2476538e3eb342803880502
--- /dev/null
+++ b/diffusers/utils/__init__.py
@@ -0,0 +1,89 @@
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from .deprecation_utils import deprecate
+from .import_utils import (
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ DummyObject,
+ is_accelerate_available,
+ is_flax_available,
+ is_inflect_available,
+ is_modelcards_available,
+ is_onnx_available,
+ is_safetensors_available,
+ is_scipy_available,
+ is_tf_available,
+ is_torch_available,
+ is_torch_version,
+ is_transformers_available,
+ is_transformers_version,
+ is_unidecode_available,
+ requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+from .pil_utils import PIL_INTERPOLATION
+
+
+if is_torch_available():
+ from .testing_utils import (
+ floats_tensor,
+ load_hf_numpy,
+ load_image,
+ load_numpy,
+ parse_flag_from_env,
+ require_torch_gpu,
+ slow,
+ torch_all_close,
+ torch_device,
+ )
+
+
+logger = get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
+ONNX_WEIGHTS_NAME = "model.onnx"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
+ "DDIMScheduler",
+ "DDPMScheduler",
+ "PNDMScheduler",
+ "LMSDiscreteScheduler",
+ "EulerDiscreteScheduler",
+ "HeunDiscreteScheduler",
+ "EulerAncestralDiscreteScheduler",
+ "DPMSolverMultistepScheduler",
+]
diff --git a/diffusers/utils/deprecation_utils.py b/diffusers/utils/deprecation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bdda664e102ea9913503b9e169fa97225d52c78
--- /dev/null
+++ b/diffusers/utils/deprecation_utils.py
@@ -0,0 +1,49 @@
+import inspect
+import warnings
+from typing import Any, Dict, Optional, Union
+
+from packaging import version
+
+
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+ from .. import __version__
+
+ deprecated_kwargs = take_from
+ values = ()
+ if not isinstance(args[0], tuple):
+ args = (args,)
+
+ for attribute, version_name, message in args:
+ if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
+ raise ValueError(
+ f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
+ f" version {__version__} is >= {version_name}"
+ )
+
+ warning = None
+ if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
+ values += (deprecated_kwargs.pop(attribute),)
+ warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
+ elif hasattr(deprecated_kwargs, attribute):
+ values += (getattr(deprecated_kwargs, attribute),)
+ warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
+ elif deprecated_kwargs is None:
+ warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
+
+ if warning is not None:
+ warning = warning + " " if standard_warn else ""
+ warnings.warn(warning + message, FutureWarning, stacklevel=2)
+
+ if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
+ call_frame = inspect.getouterframes(inspect.currentframe())[1]
+ filename = call_frame.filename
+ line_number = call_frame.lineno
+ function = call_frame.function
+ key, value = next(iter(deprecated_kwargs.items()))
+ raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+
+ if len(values) == 0:
+ return
+ elif len(values) == 1:
+ return values[0]
+ return values
diff --git a/diffusers/utils/dummy_flax_and_transformers_objects.py b/diffusers/utils/dummy_flax_and_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..14830bca2898ed550eb9a0b671282a81967c8570
--- /dev/null
+++ b/diffusers/utils/dummy_flax_and_transformers_objects.py
@@ -0,0 +1,19 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class FlaxStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax", "transformers"])
diff --git a/diffusers/utils/dummy_flax_objects.py b/diffusers/utils/dummy_flax_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e308bb41bea681993049d8a5ec3ff22987d5d14
--- /dev/null
+++ b/diffusers/utils/dummy_flax_objects.py
@@ -0,0 +1,184 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class FlaxModelMixin(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxUNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxAutoencoderKL(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDDIMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDDPMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxPNDMScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxSchedulerMixin(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+
+class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["flax"])
diff --git a/diffusers/utils/dummy_pt_objects.py b/diffusers/utils/dummy_pt_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..23afb51cf30c0273507d296a47e96da087ea5f2d
--- /dev/null
+++ b/diffusers/utils/dummy_pt_objects.py
@@ -0,0 +1,527 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class ModelMixin(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKL(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet1DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class UNet2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class VQModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+def get_constant_schedule(*args, **kwargs):
+ requires_backends(get_constant_schedule, ["torch"])
+
+
+def get_constant_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_constant_schedule_with_warmup, ["torch"])
+
+
+def get_cosine_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_cosine_schedule_with_warmup, ["torch"])
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
+
+
+def get_linear_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_linear_schedule_with_warmup, ["torch"])
+
+
+def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
+ requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
+
+
+def get_scheduler(*args, **kwargs):
+ requires_backends(get_scheduler, ["torch"])
+
+
+class DiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DanceDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDPMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KarrasVePipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class LDMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class LDMSuperResolutionPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PNDMPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class RePaintPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ScoreSdeVePipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDIMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DDPMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class DPMSolverMultistepScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EulerDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class HeunDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class IPNDMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KarrasVeScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class KDPM2DiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PNDMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class RePaintScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SchedulerMixin(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ScoreSdeVeScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class VQDiffusionScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class EMAModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
diff --git a/diffusers/utils/dummy_torch_and_scipy_objects.py b/diffusers/utils/dummy_torch_and_scipy_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..13f17349bb4504f99251eeb75ae57e047045e85e
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_scipy_objects.py
@@ -0,0 +1,19 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["torch", "scipy"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "scipy"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "scipy"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "scipy"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae9412a9568202bb8ae39ea2a07cc26208cf7aa8
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
@@ -0,0 +1,79 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class OnnxStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "onnx"])
diff --git a/diffusers/utils/dummy_torch_and_transformers_objects.py b/diffusers/utils/dummy_torch_and_transformers_objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d932d240508e138b2d30328bd4c94655b4498ba
--- /dev/null
+++ b/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -0,0 +1,244 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AltDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CycleDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class LDMTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionPipelineSafe(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionUpscalePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VQDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
diff --git a/diffusers/utils/import_utils.py b/diffusers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..531f9eab2f7ae32f818c990ea905f8c5bb98b861
--- /dev/null
+++ b/diffusers/utils/import_utils.py
@@ -0,0 +1,396 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import operator as op
+import os
+import sys
+from collections import OrderedDict
+from typing import Union
+
+from packaging import version
+from packaging.version import Version, parse
+
+from . import logging
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
+
+STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logger.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _torch_available = False
+else:
+ logger.info("Disabling PyTorch because USE_TORCH is set")
+ _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "intel-tensorflow-avx512",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ "tensorflow-aarch64",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if version.parse(_tf_version) < version.parse("2"):
+ logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+
+_jax_version = "N/A"
+_flax_version = "N/A"
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+ if _flax_available:
+ try:
+ _jax_version = importlib_metadata.version("jax")
+ _flax_version = importlib_metadata.version("flax")
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _flax_available = False
+else:
+ _flax_available = False
+
+if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _safetensors_available = importlib.util.find_spec("safetensors") is not None
+ if _safetensors_available:
+ try:
+ _safetensors_version = importlib_metadata.version("safetensors")
+ logger.info(f"Safetensors version {_safetensors_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _safetensors_available = False
+else:
+ logger.info("Disabling Safetensors because USE_TF is set")
+ _safetensors_available = False
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+ _transformers_version = importlib_metadata.version("transformers")
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+ _inflect_version = importlib_metadata.version("inflect")
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+ _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+ _unidecode_version = importlib_metadata.version("unidecode")
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+ _unidecode_available = False
+
+
+_modelcards_available = importlib.util.find_spec("modelcards") is not None
+try:
+ _modelcards_version = importlib_metadata.version("modelcards")
+ logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
+except importlib_metadata.PackageNotFoundError:
+ _modelcards_available = False
+
+
+_onnxruntime_version = "N/A"
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+if _onnx_available:
+ candidates = (
+ "onnxruntime",
+ "onnxruntime-gpu",
+ "onnxruntime-directml",
+ "onnxruntime-openvino",
+ "ort_nightly_directml",
+ )
+ _onnxruntime_version = None
+ # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+ for pkg in candidates:
+ try:
+ _onnxruntime_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _onnx_available = _onnxruntime_version is not None
+ if _onnx_available:
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+ _scipy_version = importlib_metadata.version("scipy")
+ logger.debug(f"Successfully imported transformers version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _scipy_available = False
+
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+try:
+ _accelerate_version = importlib_metadata.version("accelerate")
+ logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
+except importlib_metadata.PackageNotFoundError:
+ _accelerate_available = False
+
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+ _xformers_version = importlib_metadata.version("xformers")
+ if _torch_available:
+ import torch
+
+ if torch.__version__ < version.Version("1.12"):
+ raise ValueError("PyTorch should be >= 1.12")
+ logger.debug(f"Successfully imported xformers version {_xformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _xformers_available = False
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_safetensors_available():
+ return _safetensors_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_inflect_available():
+ return _inflect_available
+
+
+def is_unidecode_available():
+ return _unidecode_available
+
+
+def is_modelcards_available():
+ return _modelcards_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+def is_xformers_available():
+ return _xformers_available
+
+
+def is_accelerate_available():
+ return _accelerate_available
+
+
+# docstyle-ignore
+FLAX_IMPORT_ERROR = """
+{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
+installation page: https://github.com/google/flax and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+INFLECT_IMPORT_ERROR = """
+{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
+inflect`
+"""
+
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
+# docstyle-ignore
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
+scipy`
+"""
+
+# docstyle-ignore
+TENSORFLOW_IMPORT_ERROR = """
+{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
+installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
+install transformers`
+"""
+
+# docstyle-ignore
+UNIDECODE_IMPORT_ERROR = """
+{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
+Unidecode`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+ if name in [
+ "VersatileDiffusionTextToImagePipeline",
+ "VersatileDiffusionPipeline",
+ "VersatileDiffusionDualGuidedPipeline",
+ "StableDiffusionImageVariationPipeline",
+ ] and is_transformers_version("<", "4.25.0.dev0"):
+ raise ImportError(
+ f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install"
+ " git+https://github.com/huggingface/transformers \n```"
+ )
+
+
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+ `requires_backend` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
+
+
+# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
+def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+ """
+ Args:
+ Compares a library version to some requirement using a given operation.
+ library_or_version (`str` or `packaging.version.Version`):
+ A library name or a version to check.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`.
+ requirement_version (`str`):
+ The version to compare the library version against
+ """
+ if operation not in STR_OPERATION_TO_FUNC.keys():
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+ operation = STR_OPERATION_TO_FUNC[operation]
+ if isinstance(library_or_version, str):
+ library_or_version = parse(importlib_metadata.version(library_or_version))
+ return operation(library_or_version, parse(requirement_version))
+
+
+# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+def is_torch_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current PyTorch version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A string version of PyTorch
+ """
+ return compare_versions(parse(_torch_version), operation, version)
+
+
+def is_transformers_version(operation: str, version: str):
+ """
+ Args:
+ Compares the current Transformers version to a given reference with an operation.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A string version of PyTorch
+ """
+ if not _transformers_available:
+ return False
+ return compare_versions(parse(_transformers_version), operation, version)
diff --git a/diffusers/utils/logging.py b/diffusers/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1c77d10b2a6b06a0c57d4fdf1802e3bd5f705f
--- /dev/null
+++ b/diffusers/utils/logging.py
@@ -0,0 +1,340 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL # NOQA
+from logging import DEBUG # NOQA
+from logging import ERROR # NOQA
+from logging import FATAL # NOQA
+from logging import INFO # NOQA
+from logging import NOTSET # NOQA
+from logging import WARN # NOQA
+from logging import WARNING # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+ not - fall back to `_default_log_level`
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Diffusers' root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+
+
+ 🤗 Diffusers has following logging levels:
+
+ - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - 40: `diffusers.logging.ERROR`
+ - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - 20: `diffusers.logging.INFO`
+ - 10: `diffusers.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Diffusers' root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `diffusers.logging.ERROR`
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `diffusers.logging.INFO`
+ - `diffusers.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler not in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace Diffusers' loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
diff --git a/diffusers/utils/model_card_template.md b/diffusers/utils/model_card_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21
--- /dev/null
+++ b/diffusers/utils/model_card_template.md
@@ -0,0 +1,50 @@
+---
+{{ card_data }}
+---
+
+
+
+# {{ model_name | default("Diffusion Model") }}
+
+## Model description
+
+This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
+on the `{{ dataset_name }}` dataset.
+
+## Intended uses & limitations
+
+#### How to use
+
+```python
+# TODO: add an example code snippet for running this diffusion pipeline
+```
+
+#### Limitations and bias
+
+[TODO: provide examples of latent issues and potential remediations]
+
+## Training data
+
+[TODO: describe the data used to train the model]
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: {{ learning_rate }}
+- train_batch_size: {{ train_batch_size }}
+- eval_batch_size: {{ eval_batch_size }}
+- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
+- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
+- lr_scheduler: {{ lr_scheduler }}
+- lr_warmup_steps: {{ lr_warmup_steps }}
+- ema_inv_gamma: {{ ema_inv_gamma }}
+- ema_inv_gamma: {{ ema_power }}
+- ema_inv_gamma: {{ ema_max_decay }}
+- mixed_precision: {{ mixed_precision }}
+
+### Training results
+
+📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)
+
+
diff --git a/diffusers/utils/outputs.py b/diffusers/utils/outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d902dd394ccddc408d85b48e4142facc7242550
--- /dev/null
+++ b/diffusers/utils/outputs.py
@@ -0,0 +1,108 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+ """
+ if is_torch_available():
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ python dictionary.
+
+
+
+ You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+ before.
+
+
+ """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+ if other_fields_are_none and isinstance(first_field, dict):
+ for key, value in first_field.items():
+ self[key] = value
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
diff --git a/diffusers/utils/pil_utils.py b/diffusers/utils/pil_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..39d0a15a4e2fe39fecb01951b36c43368492f983
--- /dev/null
+++ b/diffusers/utils/pil_utils.py
@@ -0,0 +1,21 @@
+import PIL.Image
+import PIL.ImageOps
+from packaging import version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
diff --git a/diffusers/utils/testing_utils.py b/diffusers/utils/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf398e5b6fe5b1b2c5a909bcd43a9fd772d250af
--- /dev/null
+++ b/diffusers/utils/testing_utils.py
@@ -0,0 +1,393 @@
+import inspect
+import logging
+import os
+import random
+import re
+import unittest
+import urllib.parse
+from distutils.util import strtobool
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+
+import PIL.Image
+import PIL.ImageOps
+import requests
+from packaging import version
+
+from .import_utils import is_flax_available, is_onnx_available, is_torch_available
+
+
+global_rng = random.Random()
+
+
+if is_torch_available():
+ import torch
+
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
+ "1.12"
+ )
+
+ if is_torch_higher_equal_than_1_12:
+ # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
+ mps_backend_registered = hasattr(torch.backends, "mps")
+ torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+
+
+def torch_all_close(a, b, *args, **kwargs):
+ if not is_torch_available():
+ raise ValueError("PyTorch needs to be installed to use this function.")
+ if not torch.allclose(a, b, *args, **kwargs):
+ assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
+ return True
+
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
+ """
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
+ test_case
+ )
+
+
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
+def require_onnxruntime(test_case):
+ """
+ Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
+ """
+ return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
+
+
+def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
+ if isinstance(arry, str):
+ if arry.startswith("http://") or arry.startswith("https://"):
+ response = requests.get(arry)
+ response.raise_for_status()
+ arry = np.load(BytesIO(response.content))
+ elif os.path.isfile(arry):
+ arry = np.load(arry)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
+ )
+ elif isinstance(arry, np.ndarray):
+ pass
+ else:
+ raise ValueError(
+ "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
+ " ndarray."
+ )
+
+ return arry
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Args:
+ Loads `image` to a PIL Image.
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`: A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+ "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def load_hf_numpy(path) -> np.ndarray:
+ if not path.startswith("http://") or path.startswith("https://"):
+ path = os.path.join(
+ "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
+ )
+
+ return load_numpy(path)
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this functions taps into a private _pytest API and while unlikely, it could break should
+ pytest do internal changes - also it calls default internal methods of terminalreporter which
+ can be hijacked by various `pytest-` plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = "reports"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{id}_{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+ with open(report_files["passes"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
+
+class CaptureLogger:
+ """
+ Args:
+ Context manager to capture `logging` streams
+ logger: 'logging` logger object
+ Returns:
+ The captured output is available via `self.out`
+ Example:
+ ```python
+ >>> from diffusers import logging
+ >>> from diffusers.testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
diff --git a/inputs/00003.png b/inputs/00003.png
new file mode 100644
index 0000000000000000000000000000000000000000..00cad23adf5d658caf03a0a2874f0c89d96c5ddc
Binary files /dev/null and b/inputs/00003.png differ
diff --git a/inputs/00017_gray.png b/inputs/00017_gray.png
new file mode 100644
index 0000000000000000000000000000000000000000..79af68e8aa0f036211734b7271633d88b2fc8f0d
Binary files /dev/null and b/inputs/00017_gray.png differ
diff --git a/inputs/0014.jpg b/inputs/0014.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f59554fe3143b3ffa27d6fcb04143124b4d0412b
Binary files /dev/null and b/inputs/0014.jpg differ
diff --git a/inputs/0030.jpg b/inputs/0030.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..61868926af738046e984bcf652134e3ea9b958d9
Binary files /dev/null and b/inputs/0030.jpg differ
diff --git a/inputs/ADE_val_00000114.jpg b/inputs/ADE_val_00000114.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b4d9c9067adbcdd153527cef2c0cab4cf40bbfa5
Binary files /dev/null and b/inputs/ADE_val_00000114.jpg differ
diff --git a/inputs/OST_009.png b/inputs/OST_009.png
new file mode 100644
index 0000000000000000000000000000000000000000..10bbc831acb7065827a14eb7e0538312a8d6f3e2
Binary files /dev/null and b/inputs/OST_009.png differ
diff --git a/inputs/children-alpha.png b/inputs/children-alpha.png
new file mode 100644
index 0000000000000000000000000000000000000000..41dcc3b6cc7a8a1b073f6dbe09d0c12e18c1b4b3
Binary files /dev/null and b/inputs/children-alpha.png differ
diff --git a/inputs/tree_alpha_16bit.png b/inputs/tree_alpha_16bit.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca7c2aac2c5c9cdaea66ecc8e06d6b43e3d8bf20
Binary files /dev/null and b/inputs/tree_alpha_16bit.png differ
diff --git a/inputs/video/onepiece_demo.mp4 b/inputs/video/onepiece_demo.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..29b4e5246b19008885611c23921fe4423f17e43f
Binary files /dev/null and b/inputs/video/onepiece_demo.mp4 differ
diff --git a/inputs/wolf_gray.jpg b/inputs/wolf_gray.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..614766bdbcaa3730a8191afcb9616305381245ea
Binary files /dev/null and b/inputs/wolf_gray.jpg differ
diff --git a/repo.py b/repo.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce0d03ec429983ae59b35e1846142e7d104fc2d
--- /dev/null
+++ b/repo.py
@@ -0,0 +1,282 @@
+import os
+import random
+
+import autocuda
+from pyabsa.utils.pyabsa_utils import fprint
+
+from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, \
+ DPMSolverMultistepScheduler
+import gradio as gr
+import torch
+from PIL import Image
+import utils
+import datetime
+import time
+import psutil
+
+from Waifu2x.magnify import ImageMagnifier
+
+start_time = time.time()
+is_colab = utils.is_google_colab()
+
+device = autocuda.auto_cuda()
+
+magnifier = ImageMagnifier()
+
+class Model:
+ def __init__(self, name, path="", prefix=""):
+ self.name = name
+ self.path = path
+ self.prefix = prefix
+ self.pipe_t2i = None
+ self.pipe_i2i = None
+
+
+models = [
+ # Model("anything v3", "anything-v3.0", "anything v3 style"),
+ Model("anything v3", "Linaqruf/anything-v3.0", "anything v3 style"),
+]
+# Model("Spider-Verse", "nitrosocke/spider-verse-diffusion", "spiderverse style "),
+# Model("Balloon Art", "Fictiverse/Stable_Diffusion_BalloonArt_Model", "BalloonArt "),
+# Model("Elden Ring", "nitrosocke/elden-ring-diffusion", "elden ring style "),
+# Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy ")
+# Model("Pokémon", "lambdalabs/sd-pokemon-diffusers", ""),
+# Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
+# Model("Robo Diffusion", "nousr/robo-diffusion", ""),
+
+scheduler = DPMSolverMultistepScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000,
+ trained_betas=None,
+ predict_epsilon=True,
+ thresholding=False,
+ algorithm_type="dpmsolver++",
+ solver_type="midpoint",
+ lower_order_final=True,
+)
+
+custom_model = None
+if is_colab:
+ models.insert(0, Model("Custom model"))
+ custom_model = models[0]
+
+last_mode = "txt2img"
+current_model = models[1] if is_colab else models[0]
+current_model_path = current_model.path
+
+if is_colab:
+ pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16, scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (images, False))
+
+else: # download all models
+ print(f"{datetime.datetime.now()} Downloading vae...")
+ vae = AutoencoderKL.from_pretrained(current_model.path, subfolder="vae", torch_dtype=torch.float32
+ )
+ for model in models:
+ try:
+ print(f"{datetime.datetime.now()} Downloading {model.name} model...")
+ unet = UNet2DConditionModel.from_pretrained(model.path, subfolder="unet", torch_dtype=torch.float32
+ )
+ model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model.path, unet=unet, vae=vae,
+ torch_dtype=torch.float32,
+ scheduler=scheduler)
+# model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model.path, unet=unet, vae=vae,
+# torch_dtype=torch.float32,
+# scheduler=scheduler)
+ except Exception as e:
+ print(f"{datetime.datetime.now()} Failed to load model " + model.name + ": " + str(e))
+ models.remove(model)
+ pipe = models[0].pipe_t2i
+
+if torch.cuda.is_available():
+ pipe = pipe.to(device)
+
+device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
+
+
+def error_str(error, title="Error"):
+ return f"""#### {title}
+ {error}""" if error else ""
+
+
+def custom_model_changed(path):
+ models[0].path = path
+ global current_model
+ current_model = models[0]
+
+
+def on_model_change(model_name):
+ prefix = "Enter prompt. \"" + next((m.prefix for m in models if m.name == model_name),
+ None) + "\" is prefixed automatically" if model_name != models[
+ 0].name else "Don't forget to use the custom model prefix in the prompt!"
+
+ return gr.update(visible=model_name == models[0].name), gr.update(placeholder=prefix)
+
+
+def inference(model_name, prompt, guidance, steps, width=512, height=512, seed=0, img=None, strength=0.5,
+ neg_prompt=""):
+ print(psutil.virtual_memory()) # print memory usage
+
+ global current_model
+ for model in models:
+ if model.name == model_name:
+ current_model = model
+ model_path = current_model.path
+
+ generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
+
+ try:
+ if img is not None:
+ return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height,
+ generator), None
+ else:
+ return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator), None
+ except Exception as e:
+ fprint(e)
+ return None, error_str(e)
+
+
+def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator):
+ print(f"{datetime.datetime.now()} txt_to_img, model: {current_model.name}")
+
+ global last_mode
+ global pipe
+ global current_model_path
+ if model_path != current_model_path or last_mode != "txt2img":
+ current_model_path = model_path
+
+ if is_colab or current_model == custom_model:
+ pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float32,
+ scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (images, False))
+ else:
+ pipe = pipe.to("cpu")
+ pipe = current_model.pipe_t2i
+
+ if torch.cuda.is_available():
+ pipe = pipe.to(device)
+ last_mode = "txt2img"
+
+ prompt = current_model.prefix + prompt
+ result = pipe(
+ prompt,
+ negative_prompt=neg_prompt,
+ # num_images_per_prompt=n_images,
+ num_inference_steps=int(steps),
+ guidance_scale=guidance,
+ width=width,
+ height=height,
+ generator=generator)
+ #result.images[0] = magnifier.magnify(result.images[0])
+ #result.images[0] = magnifier.magnify(result.images[0])
+
+ # save image
+ result.images[0].save("{}/{}.{}.{}.{}.{}.{}.{}.{}.png".format(saved_path,
+ datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
+ model_name,
+ prompt,
+ guidance,
+ steps,
+ width,
+ height,
+ seed)
+ )
+ return replace_nsfw_images(result)
+
+
+def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator):
+ print(f"{datetime.datetime.now()} img_to_img, model: {model_path}")
+
+ global last_mode
+ global pipe
+ global current_model_path
+ if model_path != current_model_path or last_mode != "img2img":
+ current_model_path = model_path
+
+ if is_colab or current_model == custom_model:
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float32,
+ scheduler=scheduler,
+ safety_checker=lambda images, clip_input: (
+ images, False))
+ else:
+ pipe = pipe.to("cpu")
+ pipe = current_model.pipe_i2i
+
+ if torch.cuda.is_available():
+ pipe = pipe.to(device)
+ last_mode = "img2img"
+
+ prompt = current_model.prefix + prompt
+ ratio = min(height / img.height, width / img.width)
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
+ result = pipe(
+ prompt,
+ negative_prompt=neg_prompt,
+ # num_images_per_prompt=n_images,
+ init_image=img,
+ num_inference_steps=int(steps),
+ strength=strength,
+ guidance_scale=guidance,
+ width=width,
+ height=height,
+ generator=generator)
+ result.images[0] = magnifier.magnify(result.images[0])
+ result.images[0] = magnifier.magnify(result.images[0])
+
+ # save image
+ result.images[0].save("{}/{}.{}.{}.{}.{}.{}.{}.{}.png".format(saved_path,
+ datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
+ model_name,
+ prompt,
+ guidance,
+ steps,
+ width,
+ height,
+ seed)
+ )
+ return replace_nsfw_images(result)
+
+
+def replace_nsfw_images(results):
+ if is_colab:
+ return results.images[0]
+
+ for i in range(len(results.images)):
+ if results.nsfw_content_detected[i]:
+ results.images[i] = Image.open("nsfw.png")
+ return results.images[0]
+
+
+if __name__ == "__main__":
+ # inference("DALL-E", "a dog", 0, 1000, 512, 512, 0, None, 0.5, "")
+ model_name = "anything v3"
+ saved_path = r"imgs"
+ if not os.path.exists(saved_path):
+ os.mkdir(saved_path)
+ n = 0
+ while True:
+ prompt_keys = [
+ 'beautiful eyes', 'cumulonimbus clouds', 'sky', 'detailed fingers',
+ random.choice(['white hair', 'red hair', 'blonde hair', 'black hair', 'green hair', ]),
+ random.choice(['blue eyes', 'green eyes', 'red eyes', 'black eyes', 'yellow eyes', ]),
+ random.choice(['flower meadow', 'garden', 'city', 'river', 'beach']),
+ random.choice(['Elif', 'Angel'])
+ ]
+ guidance = 7.5
+ steps = 25
+ # width = 1024
+ # height = 1024
+ # width = 768
+ # height = 1024
+ width = 512
+ height = 888
+ seed = 0
+ img = None
+ strength = 0.5
+ neg_prompt = ""
+ inference(model_name, '.'.join(prompt_keys), guidance, steps, width=width, height=height, seed=seed, img=img,
+ strength=strength, neg_prompt=neg_prompt)
+ n += 1
+ fprint(n)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..acf4e52a24482afdf80ef741a8fcd02a2551f092
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch
+torchvision
+autocuda
+findfile
+pyabsa
+git+https://github.com/huggingface/diffusers.git
+scipy
+git+https://github.com/huggingface/transformers.git
+ftfy
+accelerate
+psutil
+gradio
+
+# requirements for RealESRGAN
+basicsr>=1.4.2
+facexlib>=0.2.5
+gfpgan>=1.3.5
+numpy
+opencv-python
+Pillow
+tqdm
diff --git a/results/00003_out.png b/results/00003_out.png
new file mode 100644
index 0000000000000000000000000000000000000000..133455e9c204f2cbde887307630e4eef885e6219
--- /dev/null
+++ b/results/00003_out.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6e10c1078b0c1e09c77c3d6246a47cfdefb41c1cb976bf80e6fc17c415c29dd
+size 2499904
diff --git a/run_webUI.py b/run_webUI.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3ae7ef02abac582ae58fcf66aa6170ce0956bb
--- /dev/null
+++ b/run_webUI.py
@@ -0,0 +1,11 @@
+import os
+
+try:
+ from RealESRGANv030.interface import realEsrgan
+except:
+ os.system('cd RealESRGANv030 && python setup.py develop')
+
+status = os.system('python app.py')
+if status!=0:
+ print('Run failed, try set MKL_THREADING_LAYER=GNU\n')
+ os.system('export MKL_THREADING_LAYER=GNU && python app.py')
\ No newline at end of file
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..79d03d4905e60bebf2f357a1ba97ffedd0f0d5f5
--- /dev/null
+++ b/style.css
@@ -0,0 +1,22 @@
+.finetuned-diffusion-div div{
+ display:inline-flex;
+ align-items:center;
+ gap:.8rem;
+ font-size:1.75rem
+ }
+
+.finetuned-diffusion-div div h1{
+ font-weight:900;
+ margin-bottom:7px
+ }
+
+.finetuned-diffusion-div p{
+ margin-bottom:10px;
+ font-size:94%}
+
+a{text-decoration:underline}
+
+.tabs{margin-top:0;
+ margin-bottom:0}
+
+#gallery{min-height:20rem}
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff1c065d186347ca51b47d010a697dbe1814695c
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,6 @@
+def is_google_colab():
+ try:
+ import google.colab
+ return True
+ except:
+ return False
\ No newline at end of file