Spaces:
Runtime error
Runtime error
First commit
Browse files- .gitattributes +0 -34
- .github/workflows/update_space.yml +28 -0
- .gitignore +216 -0
- .idea/.gitignore +8 -0
- .idea/DeepLearning.iml +8 -0
- .idea/deployment.xml +29 -0
- .idea/inspectionProfiles/Project_Default.xml +34 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .python-version +1 -0
- PikaPikaTraining.ipynb +112 -0
- pikapikagen/PikaPikaGen.ipynb +2241 -0
- pikapikagen/README.md +6 -0
- pikapikagen/__init__.py +0 -0
- pikapikagen/data_loader.py +100 -0
- pikapikagen/dataset.py +141 -0
- pikapikagen/discriminators.py +161 -0
- pikapikagen/evaluate_kid.py +141 -0
- pikapikagen/gradio_demo.py +291 -0
- pikapikagen/losses.py +103 -0
- pikapikagen/model.py +46 -0
- pikapikagen/model_blocks/decoder_block.py +59 -0
- pikapikagen/model_blocks/image_cross_attention.py +49 -0
- pikapikagen/model_blocks/image_decoder.py +122 -0
- pikapikagen/model_blocks/text_encoder.py +43 -0
- pikapikagen/model_checkpoint/checkpoint_epoch_150.pth +3 -0
- pikapikagen/plots.py +428 -0
- pikapikagen/utils.py +12 -0
- pyproject.toml +18 -0
- uv.lock +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.github/workflows/update_space.yml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Run Python script
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
build:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
|
| 12 |
+
steps:
|
| 13 |
+
- name: Checkout
|
| 14 |
+
uses: actions/checkout@v2
|
| 15 |
+
|
| 16 |
+
- name: Set up Python
|
| 17 |
+
uses: actions/setup-python@v2
|
| 18 |
+
with:
|
| 19 |
+
python-version: '3.9'
|
| 20 |
+
|
| 21 |
+
- name: Install Gradio
|
| 22 |
+
run: python -m pip install gradio
|
| 23 |
+
|
| 24 |
+
- name: Log in to Hugging Face
|
| 25 |
+
run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
|
| 26 |
+
|
| 27 |
+
- name: Deploy to Spaces
|
| 28 |
+
run: gradio deploy
|
.gitignore
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# Streamlit
|
| 210 |
+
.streamlit/secrets.toml
|
| 211 |
+
|
| 212 |
+
# Project
|
| 213 |
+
/pikapikagen/dataset
|
| 214 |
+
/pikapikagen/training_output
|
| 215 |
+
/dataset
|
| 216 |
+
/old_notebooks/dataset
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/DeepLearning.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/deployment.xml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
|
| 4 |
+
<serverData>
|
| 5 |
+
<paths name="root@salad:22 agent">
|
| 6 |
+
<serverdata>
|
| 7 |
+
<mappings>
|
| 8 |
+
<mapping deploy="/tmp/pycharm_project_71" local="$PROJECT_DIR$" />
|
| 9 |
+
</mappings>
|
| 10 |
+
</serverdata>
|
| 11 |
+
</paths>
|
| 12 |
+
<paths name="val@46.101.132.64:22 key">
|
| 13 |
+
<serverdata>
|
| 14 |
+
<mappings>
|
| 15 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
| 16 |
+
</mappings>
|
| 17 |
+
</serverdata>
|
| 18 |
+
</paths>
|
| 19 |
+
<paths name="val@46.101.132.64:22 key (2)">
|
| 20 |
+
<serverdata>
|
| 21 |
+
<mappings>
|
| 22 |
+
<mapping local="$PROJECT_DIR$" web="/" />
|
| 23 |
+
</mappings>
|
| 24 |
+
</serverdata>
|
| 25 |
+
</paths>
|
| 26 |
+
</serverData>
|
| 27 |
+
<option name="myAutoUpload" value="ALWAYS" />
|
| 28 |
+
</component>
|
| 29 |
+
</project>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
| 5 |
+
<Languages>
|
| 6 |
+
<language minSize="49" name="Python" />
|
| 7 |
+
</Languages>
|
| 8 |
+
</inspection_tool>
|
| 9 |
+
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
| 10 |
+
<inspection_tool class="Mypy" enabled="true" level="TYPO" enabled_by_default="true" editorAttributes="TYPO" />
|
| 11 |
+
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
| 12 |
+
<option name="ignoredErrors">
|
| 13 |
+
<list>
|
| 14 |
+
<option value="N802" />
|
| 15 |
+
<option value="N803" />
|
| 16 |
+
<option value="N806" />
|
| 17 |
+
</list>
|
| 18 |
+
</option>
|
| 19 |
+
</inspection_tool>
|
| 20 |
+
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 21 |
+
<option name="ignoredIdentifiers">
|
| 22 |
+
<list>
|
| 23 |
+
<option value="fitz.fitz.Page.MediaBox" />
|
| 24 |
+
<option value="color_tol" />
|
| 25 |
+
</list>
|
| 26 |
+
</option>
|
| 27 |
+
</inspection_tool>
|
| 28 |
+
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
|
| 29 |
+
<option name="processCode" value="true" />
|
| 30 |
+
<option name="processLiterals" value="true" />
|
| 31 |
+
<option name="processComments" value="true" />
|
| 32 |
+
</inspection_tool>
|
| 33 |
+
</profile>
|
| 34 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/DeepLearning.iml" filepath="$PROJECT_DIR$/.idea/DeepLearning.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
PikaPikaTraining.ipynb
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# PikaPikaGen: Training del Modello\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Questo notebook automatizza il processo di setup e avvio del training per il modello PikaPikaGen.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"I passaggi eseguiti sono:\n",
|
| 12 |
+
"1. Clonazione del repository GitHub pubblico.\n",
|
| 13 |
+
"2. Installazione delle dipendenze necessarie tramite `uv`.\n",
|
| 14 |
+
"3. Esecuzione dello script di training `main.py`."
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": null,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"print(\"Installazione delle dipendenze necessarie...\")\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"# Assicurati che uv sia installato\n",
|
| 26 |
+
"%pip install uv\n",
|
| 27 |
+
"print(\"✅ uv installato con successo.\")\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"# Controlla se torch è già installato\n",
|
| 30 |
+
"try:\n",
|
| 31 |
+
" import torch\n",
|
| 32 |
+
" print(f\"✅ PyTorch già installato (versione: {torch.__version__})\")\n",
|
| 33 |
+
" torch_installed = True\n",
|
| 34 |
+
"except ImportError:\n",
|
| 35 |
+
" print(\"❌ PyTorch non trovato, sarà installato\")\n",
|
| 36 |
+
" torch_installed = False\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# Lista delle dipendenze principali del progetto\n",
|
| 39 |
+
"dependencies = [\n",
|
| 40 |
+
" \"transformers\",\n",
|
| 41 |
+
" \"pandas\",\n",
|
| 42 |
+
" \"tqdm\",\n",
|
| 43 |
+
" \"matplotlib\",\n",
|
| 44 |
+
" \"Pillow\",\n",
|
| 45 |
+
" \"requests\",\n",
|
| 46 |
+
" \"ipywidgets\"\n",
|
| 47 |
+
"]\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"# Aggiungi torch e torchvision solo se non sono già installati\n",
|
| 50 |
+
"if not torch_installed:\n",
|
| 51 |
+
" dependencies.extend([\"torch\", \"torchvision\"])\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"print(\"Installazione delle dipendenze con uv...\")\n",
|
| 54 |
+
"deps_str = \" \".join(dependencies)\n",
|
| 55 |
+
"if torch_installed:\n",
|
| 56 |
+
" !uv pip install {deps_str}\n",
|
| 57 |
+
"else:\n",
|
| 58 |
+
" !uv pip install {deps_str} --torch-backend=auto\n",
|
| 59 |
+
"print(\"✅ Dipendenze principali installate con successo.\")\n"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": null,
|
| 65 |
+
"metadata": {},
|
| 66 |
+
"outputs": [],
|
| 67 |
+
"source": [
|
| 68 |
+
"import os\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"repo_url = \"https://github.com/val-2/DeepLearning\"\n",
|
| 71 |
+
"branch = \"main\"\n",
|
| 72 |
+
"repo_name = repo_url.split('/')[-1]\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"print(f\"Clonazione del repository: {repo_url}\")\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# Check if we're already in the repo directory\n",
|
| 77 |
+
"current_dir = os.path.basename(os.getcwd())\n",
|
| 78 |
+
"if current_dir == repo_name:\n",
|
| 79 |
+
" print(f\"Già nella directory del repository '{repo_name}'. Aggiornamento...\")\n",
|
| 80 |
+
" !git fetch\n",
|
| 81 |
+
" !git pull\n",
|
| 82 |
+
" !git checkout {branch}\n",
|
| 83 |
+
"elif os.path.exists(repo_name):\n",
|
| 84 |
+
" print(f\"La directory '{repo_name}' esiste già. Aggiornamento del repository...\")\n",
|
| 85 |
+
" os.chdir(repo_name)\n",
|
| 86 |
+
" !git fetch\n",
|
| 87 |
+
" !git pull\n",
|
| 88 |
+
" !git checkout {branch}\n",
|
| 89 |
+
"else:\n",
|
| 90 |
+
" print(\"Clonazione del repository...\")\n",
|
| 91 |
+
" !git clone -b {branch} {repo_url}\n",
|
| 92 |
+
" os.chdir(repo_name)\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"# Spostati nella directory del repository\n",
|
| 95 |
+
"print(f\"Directory di lavoro corrente: {os.getcwd()}\")"
|
| 96 |
+
]
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"metadata": {
|
| 100 |
+
"kernelspec": {
|
| 101 |
+
"display_name": ".venv",
|
| 102 |
+
"language": "python",
|
| 103 |
+
"name": "python3"
|
| 104 |
+
},
|
| 105 |
+
"language_info": {
|
| 106 |
+
"name": "python",
|
| 107 |
+
"version": "3.12.11"
|
| 108 |
+
}
|
| 109 |
+
},
|
| 110 |
+
"nbformat": 4,
|
| 111 |
+
"nbformat_minor": 2
|
| 112 |
+
}
|
pikapikagen/PikaPikaGen.ipynb
ADDED
|
@@ -0,0 +1,2241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "raw",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "VDSaH9SVsnNl",
|
| 7 |
+
"vscode": {
|
| 8 |
+
"languageId": "raw"
|
| 9 |
+
}
|
| 10 |
+
},
|
| 11 |
+
"source": [
|
| 12 |
+
"# PikaPikaGen: Text-to-Image Pokemon Sprite Generation with GAN\n"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"# Install required packages\n",
|
| 22 |
+
"#!pip install torch torchvision transformers pandas pillow requests matplotlib tqdm ipywidgets gradio torch-fidelity\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"import torch\n",
|
| 25 |
+
"import torch.nn as nn\n",
|
| 26 |
+
"import torch.optim as optim\n",
|
| 27 |
+
"import torch.nn.functional as F\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import numpy as np\n",
|
| 30 |
+
"import matplotlib.pyplot as plt\n",
|
| 31 |
+
"import os\n",
|
| 32 |
+
"from tqdm import tqdm\n",
|
| 33 |
+
"from transformers import AutoTokenizer\n",
|
| 34 |
+
"import warnings\n",
|
| 35 |
+
"warnings.filterwarnings('ignore')\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"# Set device\n",
|
| 38 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 39 |
+
"print(f\"Using device: {device}\")\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# Set random seeds for reproducibility\n",
|
| 42 |
+
"RANDOM_SEED = 42\n",
|
| 43 |
+
"torch.manual_seed(RANDOM_SEED)\n",
|
| 44 |
+
"np.random.seed(RANDOM_SEED)"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "raw",
|
| 49 |
+
"metadata": {
|
| 50 |
+
"id": "-rrtsHGqsnNo",
|
| 51 |
+
"vscode": {
|
| 52 |
+
"languageId": "raw"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"source": [
|
| 56 |
+
"## 1. Data Loading and Preprocessing\n"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 4,
|
| 62 |
+
"metadata": {
|
| 63 |
+
"id": "aeVuv1YCsnNp"
|
| 64 |
+
},
|
| 65 |
+
"outputs": [],
|
| 66 |
+
"source": [
|
| 67 |
+
"import torch\n",
|
| 68 |
+
"import torchvision.transforms as T\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"class AugmentationPipeline:\n",
|
| 72 |
+
" def __init__(self, p=0.8):\n",
|
| 73 |
+
" self.p = p\n",
|
| 74 |
+
" self.transforms = T.RandomApply([\n",
|
| 75 |
+
" T.RandomHorizontalFlip(p=0.5),\n",
|
| 76 |
+
"\n",
|
| 77 |
+
" T.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=1),\n",
|
| 78 |
+
"\n",
|
| 79 |
+
" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
| 80 |
+
"\n",
|
| 81 |
+
" T.RandomErasing(p=0.15, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random'),\n",
|
| 82 |
+
" ], p=self.p)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
" def apply(self, images):\n",
|
| 85 |
+
" return self.transforms(images)"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": null,
|
| 91 |
+
"metadata": {
|
| 92 |
+
"colab": {
|
| 93 |
+
"base_uri": "https://localhost:8080/",
|
| 94 |
+
"height": 1000,
|
| 95 |
+
"referenced_widgets": [
|
| 96 |
+
"5efdceae0bac4c978d3a7226247e237f",
|
| 97 |
+
"a39c5c623a3e42448e109fb9ec6bc263",
|
| 98 |
+
"a6ed2ddb1c6f4d1aa945c5a39372f781",
|
| 99 |
+
"8cf950b898e142c1af9b4db92019aa4d",
|
| 100 |
+
"8ed7abd0602c43a1bfc0f96d7611d429",
|
| 101 |
+
"65ba2d78fde14bb2baf5ae1101d7e5ff",
|
| 102 |
+
"4795a78a75dc439a8da7df58bf738940",
|
| 103 |
+
"4545ff199b874d3680a83918513e1d4b",
|
| 104 |
+
"cad8fd90586443778568a1babb8c40e6",
|
| 105 |
+
"57e526d188b9414dabb3b1c895373864",
|
| 106 |
+
"8226a55726c54abba3a48dbfa8e1b6f6",
|
| 107 |
+
"86a3c1a4e9eb4989b23364f21e5df531",
|
| 108 |
+
"5ba39d9d997a45ca848e3e2ffd0e7307",
|
| 109 |
+
"4c22e1b396f342ffb90c1b50a0051862",
|
| 110 |
+
"370e5663868f411697bfb24f4e3efa09",
|
| 111 |
+
"3a338ac4d2944030a07843d8ea24e9fd",
|
| 112 |
+
"128f4312bcdc4166b9e24d8cdd34184d",
|
| 113 |
+
"1b65d6c8540e4f458886d5e7075ab30a",
|
| 114 |
+
"a5a9f8607fdd4f9cad7519eca573f3dc",
|
| 115 |
+
"926149594f94457295c60b4fad9cbac7",
|
| 116 |
+
"7e89bc79516f405e9684eacdce7b4551",
|
| 117 |
+
"c917f3a000fb44338e4afbeabeaab55f"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
"id": "ppTYW-n5snNp",
|
| 121 |
+
"outputId": "4d7a3003-296a-458c-a339-aeacf5232c91"
|
| 122 |
+
},
|
| 123 |
+
"outputs": [],
|
| 124 |
+
"source": [
|
| 125 |
+
"from data_loader import create_training_setup\n",
|
| 126 |
+
"from utils import denormalize_image\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-mini')\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"# train_augmentation_pipeline = AugmentationPipeline()\n",
|
| 131 |
+
"# Create the complete training setup using the function from pokemon_dataset.py\n",
|
| 132 |
+
"print(\"Creating training setup with train/val split and fixed batches...\")\n",
|
| 133 |
+
"training_setup = create_training_setup(\n",
|
| 134 |
+
" tokenizer=tokenizer,\n",
|
| 135 |
+
" test_set_size=0.2,\n",
|
| 136 |
+
" val_set_size=0.1,\n",
|
| 137 |
+
" batch_size=16,\n",
|
| 138 |
+
" num_workers=0,\n",
|
| 139 |
+
" num_viz_samples=4,\n",
|
| 140 |
+
" random_seed=42,\n",
|
| 141 |
+
" train_augmentation_pipeline=None\n",
|
| 142 |
+
")\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"# Extract components\n",
|
| 145 |
+
"train_loader = training_setup['train_loader']\n",
|
| 146 |
+
"val_loader = training_setup['val_loader']\n",
|
| 147 |
+
"fixed_train_batch = training_setup['fixed_train_batch']\n",
|
| 148 |
+
"fixed_val_batch = training_setup['fixed_val_batch']\n",
|
| 149 |
+
"fixed_train_attention_batch = training_setup['fixed_train_attention_batch']\n",
|
| 150 |
+
"fixed_val_attention_batch = training_setup['fixed_val_attention_batch']\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"print(\"Training setup complete!\")\n",
|
| 153 |
+
"print(f\"Train loader batches: {len(train_loader)}\")\n",
|
| 154 |
+
"print(f\"Val loader batches: {len(val_loader)}\")\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# Test the training setup with fixed batches\n",
|
| 157 |
+
"print(\"\\nFixed batch shapes:\")\n",
|
| 158 |
+
"print(f\" Train batch - Images: {fixed_train_batch['image'].shape}\")\n",
|
| 159 |
+
"print(f\" Train batch - Text: {fixed_train_batch['text'].shape}\")\n",
|
| 160 |
+
"print(f\" Train batch - Attention: {fixed_train_batch['attention_mask'].shape}\")\n",
|
| 161 |
+
"print(f\" Val batch - Images: {fixed_val_batch['image'].shape}\")\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"# Display sample images from fixed batches\n",
|
| 164 |
+
"fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
|
| 165 |
+
"for i in range(4):\n",
|
| 166 |
+
" # Fixed train batch images\n",
|
| 167 |
+
" train_img = denormalize_image(fixed_train_batch['image'][i])\n",
|
| 168 |
+
" axes[0, i].imshow(train_img.permute(1, 2, 0))\n",
|
| 169 |
+
" axes[0, i].set_title(f\"Train: {fixed_train_batch['pokemon_name'][i]}\")\n",
|
| 170 |
+
" axes[0, i].axis('off')\n",
|
| 171 |
+
"\n",
|
| 172 |
+
" # Fixed val batch images\n",
|
| 173 |
+
" val_img = denormalize_image(fixed_val_batch['image'][i])\n",
|
| 174 |
+
" axes[1, i].imshow(val_img.permute(1, 2, 0))\n",
|
| 175 |
+
" axes[1, i].set_title(f\"Val: {fixed_val_batch['pokemon_name'][i]}\")\n",
|
| 176 |
+
" axes[1, i].axis('off')\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"plt.suptitle(\"Fixed Batches for Training Visualization\", fontsize=16)\n",
|
| 179 |
+
"plt.tight_layout()\n",
|
| 180 |
+
"plt.show()\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"print(\"\\n✅ Dataset and batches loaded successfully from pokemon_dataset.py functionality!\")\n",
|
| 184 |
+
"print(\"Ready for training with proper train/val split and fixed visualization batches.\")\n"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"cell_type": "raw",
|
| 189 |
+
"metadata": {
|
| 190 |
+
"id": "eJSVrf3ysnNq",
|
| 191 |
+
"vscode": {
|
| 192 |
+
"languageId": "raw"
|
| 193 |
+
}
|
| 194 |
+
},
|
| 195 |
+
"source": [
|
| 196 |
+
"## 2. Model Architecture Implementation\n"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "code",
|
| 201 |
+
"execution_count": null,
|
| 202 |
+
"metadata": {
|
| 203 |
+
"colab": {
|
| 204 |
+
"base_uri": "https://localhost:8080/",
|
| 205 |
+
"height": 923,
|
| 206 |
+
"referenced_widgets": [
|
| 207 |
+
"bdf500351aea42698c6d6dd5a99021f3",
|
| 208 |
+
"ab61b90c1a5b4a2b9bb5c9d5a215bb3f",
|
| 209 |
+
"dc03fed540b74f3aa4a1b17ebf2c81d3",
|
| 210 |
+
"5837f2c4668646c0a6db2407aebb46e3",
|
| 211 |
+
"edeb423e9ff84e5c8a0d790368d68bba",
|
| 212 |
+
"bf8eb066cdaf4ac096dc14392d085daf",
|
| 213 |
+
"4e32e76c44fb449c8cb767abeb17868a",
|
| 214 |
+
"5c3cb981f324446eae642f7c23a539f0",
|
| 215 |
+
"2fe9614fe5984fa6b887d1e1b3e18b04",
|
| 216 |
+
"64277772cc30408e8ea29f0e268c8880",
|
| 217 |
+
"5b0d55ea20714104818097bd7d1f509a",
|
| 218 |
+
"7e21c6a9c7f44496b6f28513caefb631",
|
| 219 |
+
"439eba0eb4184c0ab83f65fc26bbe388",
|
| 220 |
+
"eee695744ec64aa7b71b9e85968c6f8f",
|
| 221 |
+
"c4ecdc9d982f49129368893c1c0aece9",
|
| 222 |
+
"5f5e7ff6e4c845b99602a4fa00ad550a",
|
| 223 |
+
"304d50e74ad744cdb3a7cc88739cb923",
|
| 224 |
+
"bfcc6d01c9ff4db698afa4318e7c91ac",
|
| 225 |
+
"b2bf751bb96746e4a828241f70e52050",
|
| 226 |
+
"828b227361fe45cd83964149e7475503",
|
| 227 |
+
"58ab975eaba2485cb0945482c26ecf3d",
|
| 228 |
+
"d0b4e43ab5cd4edda6cc061b36bf10a3"
|
| 229 |
+
]
|
| 230 |
+
},
|
| 231 |
+
"id": "RnNQM3_ysnNr",
|
| 232 |
+
"outputId": "6905696e-05d6-4d97-dd9b-9dc36eea95b7"
|
| 233 |
+
},
|
| 234 |
+
"outputs": [],
|
| 235 |
+
"source": [
|
| 236 |
+
"from model import Generator\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"# Test the generator\n",
|
| 239 |
+
"generator = Generator().to(device)\n",
|
| 240 |
+
"with torch.no_grad():\n",
|
| 241 |
+
" generated_images_256, generated_images_64 = generator(\n",
|
| 242 |
+
" fixed_train_batch['text'][:2].to(device),\n",
|
| 243 |
+
" fixed_train_batch['attention_mask'][:2].to(device)\n",
|
| 244 |
+
" )\n",
|
| 245 |
+
"print(f\"Generator output shape 256x256: {generated_images_256.shape}\")\n",
|
| 246 |
+
"print(f\"Generator output shape 64x64: {generated_images_64.shape}\")\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"print(\"Generator test\")\n",
|
| 249 |
+
"plt.figure(figsize=(12, 8))\n",
|
| 250 |
+
"for i in range(2):\n",
|
| 251 |
+
" # 256x256 images\n",
|
| 252 |
+
" plt.subplot(2, 2, i+1)\n",
|
| 253 |
+
" img_256 = denormalize_image(generated_images_256[i].cpu())\n",
|
| 254 |
+
" plt.imshow(img_256.permute(1, 2, 0))\n",
|
| 255 |
+
" plt.title(f\"Generated 256x256 Sample {i+1}\")\n",
|
| 256 |
+
" plt.axis('off')\n",
|
| 257 |
+
"\n",
|
| 258 |
+
" # 64x64 images\n",
|
| 259 |
+
" plt.subplot(2, 2, i+3)\n",
|
| 260 |
+
" img_64 = denormalize_image(generated_images_64[i].cpu())\n",
|
| 261 |
+
" plt.imshow(img_64.permute(1, 2, 0))\n",
|
| 262 |
+
" plt.title(f\"Generated 64x64 Sample {i+1}\")\n",
|
| 263 |
+
" plt.axis('off')\n",
|
| 264 |
+
"plt.tight_layout()\n",
|
| 265 |
+
"plt.show()\n"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "raw",
|
| 270 |
+
"metadata": {
|
| 271 |
+
"id": "7drCU21JsnNs",
|
| 272 |
+
"vscode": {
|
| 273 |
+
"languageId": "raw"
|
| 274 |
+
}
|
| 275 |
+
},
|
| 276 |
+
"source": [
|
| 277 |
+
"## 3. Training Setup and Utilities\n"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"cell_type": "code",
|
| 282 |
+
"execution_count": null,
|
| 283 |
+
"metadata": {
|
| 284 |
+
"colab": {
|
| 285 |
+
"base_uri": "https://localhost:8080/"
|
| 286 |
+
},
|
| 287 |
+
"id": "iQdhzEQQsnNs",
|
| 288 |
+
"outputId": "2dbee275-3b6d-43da-8929-97e21403821f"
|
| 289 |
+
},
|
| 290 |
+
"outputs": [],
|
| 291 |
+
"source": [
|
| 292 |
+
"from discriminators import Discriminator256, Discriminator64\n",
|
| 293 |
+
"from losses import VGGPerceptualLoss, SobelLoss\n",
|
| 294 |
+
"from plots import save_attention_visualization\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"def weights_init(m):\n",
|
| 297 |
+
" \"\"\"Initialize model weights according to the original DCGAN paper\"\"\"\n",
|
| 298 |
+
" classname = m.__class__.__name__\n",
|
| 299 |
+
" if classname.find('Conv') != -1:\n",
|
| 300 |
+
" nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
|
| 301 |
+
" elif classname.find('BatchNorm') != -1:\n",
|
| 302 |
+
" nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
|
| 303 |
+
" nn.init.constant_(m.bias.data, 0)\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"generator = Generator().to(device)\n",
|
| 306 |
+
"discriminator_256 = Discriminator256().to(device)\n",
|
| 307 |
+
"discriminator_64 = Discriminator64().to(device)\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"generator.apply(weights_init)\n",
|
| 310 |
+
"discriminator_256.apply(weights_init)\n",
|
| 311 |
+
"discriminator_64.apply(weights_init)\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"# Optimizer params\n",
|
| 315 |
+
"lr = 0.0002\n",
|
| 316 |
+
"beta1 = 0.5\n",
|
| 317 |
+
"beta2 = 0.999\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))\n",
|
| 320 |
+
"optimizer_D_256 = optim.Adam(discriminator_256.parameters(), lr=lr, betas=(beta1, beta2))\n",
|
| 321 |
+
"optimizer_D_64 = optim.Adam(discriminator_64.parameters(), lr=lr, betas=(beta1, beta2))\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"adv_criterion = nn.BCEWithLogitsLoss().to(device) # no sigmoid at the end of discriminators\n",
|
| 324 |
+
"l1_criterion = nn.L1Loss().to(device)\n",
|
| 325 |
+
"perc_criterion = VGGPerceptualLoss(device)\n",
|
| 326 |
+
"sobel_criterion = SobelLoss().to(device)"
|
| 327 |
+
]
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"cell_type": "code",
|
| 331 |
+
"execution_count": null,
|
| 332 |
+
"metadata": {},
|
| 333 |
+
"outputs": [],
|
| 334 |
+
"source": [
|
| 335 |
+
"from typing import TypedDict\n",
|
| 336 |
+
"import torch\n",
|
| 337 |
+
"from plots import save_comparison_grid\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"# Create checkpoint saving directory\n",
|
| 340 |
+
"os.makedirs('models', exist_ok=True)\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"# TypedDicts to pass and return many object at once, without\n",
|
| 343 |
+
"class LossesDict(TypedDict):\n",
|
| 344 |
+
" \"\"\"History of training losses\"\"\"\n",
|
| 345 |
+
" generator: list[float]\n",
|
| 346 |
+
" discriminator: list[float]\n",
|
| 347 |
+
" l1: list[float]\n",
|
| 348 |
+
" perceptual: list[float]\n",
|
| 349 |
+
" sobel: list[float]\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"class ValidationLossesDict(TypedDict):\n",
|
| 352 |
+
" \"\"\"History of validation losses\"\"\"\n",
|
| 353 |
+
" l1: list[float]\n",
|
| 354 |
+
" perceptual: list[float]\n",
|
| 355 |
+
" sobel: list[float]\n",
|
| 356 |
+
" total: list[float]\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"class DiscriminatorComponentsDict(TypedDict):\n",
|
| 359 |
+
" \"\"\"Components of the discriminator loss\"\"\"\n",
|
| 360 |
+
" real_uncond: float\n",
|
| 361 |
+
" real_cond: float\n",
|
| 362 |
+
" real_cond_wrong: float\n",
|
| 363 |
+
" fake_uncond: float\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"class ValidationResultsDict(TypedDict):\n",
|
| 366 |
+
" \"\"\"Single losses for validation\"\"\"\n",
|
| 367 |
+
" l1: float\n",
|
| 368 |
+
" perceptual: float\n",
|
| 369 |
+
" sobel: float\n",
|
| 370 |
+
" total: float\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"# Training history\n",
|
| 373 |
+
"losses: LossesDict = {\n",
|
| 374 |
+
" 'generator': [],\n",
|
| 375 |
+
" 'discriminator': [],\n",
|
| 376 |
+
" 'l1': [],\n",
|
| 377 |
+
" 'perceptual': [],\n",
|
| 378 |
+
" 'sobel': [],\n",
|
| 379 |
+
"}\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"# Validation history\n",
|
| 382 |
+
"val_losses: ValidationLossesDict = {\n",
|
| 383 |
+
" 'l1': [],\n",
|
| 384 |
+
" 'perceptual': [],\n",
|
| 385 |
+
" 'sobel': [],\n",
|
| 386 |
+
" 'total': [],\n",
|
| 387 |
+
"}\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"def validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion) -> ValidationResultsDict:\n",
|
| 390 |
+
" \"\"\"\n",
|
| 391 |
+
" Validate the model on the validation set\n",
|
| 392 |
+
" Returns validation losses\n",
|
| 393 |
+
" \"\"\"\n",
|
| 394 |
+
" generator.eval()\n",
|
| 395 |
+
"\n",
|
| 396 |
+
" val_l1_loss = 0.0\n",
|
| 397 |
+
" val_perc_loss = 0.0\n",
|
| 398 |
+
" val_sobel_loss = 0.0\n",
|
| 399 |
+
" num_batches = 0\n",
|
| 400 |
+
"\n",
|
| 401 |
+
" with torch.no_grad():\n",
|
| 402 |
+
" for batch in val_loader:\n",
|
| 403 |
+
" # Move data to device\n",
|
| 404 |
+
" real_images = batch['image'].to(device)\n",
|
| 405 |
+
" text_ids = batch['text'].to(device)\n",
|
| 406 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
| 407 |
+
"\n",
|
| 408 |
+
" # Generate images\n",
|
| 409 |
+
" generated_images, _ = generator(text_ids, attention_mask)\n",
|
| 410 |
+
"\n",
|
| 411 |
+
" # Calculate validation losses (no adversarial loss)\n",
|
| 412 |
+
" batch_l1_loss = l1_criterion(generated_images, real_images)\n",
|
| 413 |
+
" batch_perc_loss = perc_criterion(generated_images, real_images)\n",
|
| 414 |
+
" batch_sobel_loss = sobel_criterion(generated_images, real_images)\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" val_l1_loss += batch_l1_loss.item()\n",
|
| 417 |
+
" val_perc_loss += batch_perc_loss.item()\n",
|
| 418 |
+
" val_sobel_loss += batch_sobel_loss.item()\n",
|
| 419 |
+
" num_batches += 1\n",
|
| 420 |
+
"\n",
|
| 421 |
+
" # Calculate averages\n",
|
| 422 |
+
" avg_val_l1 = val_l1_loss / num_batches\n",
|
| 423 |
+
" avg_val_perc = val_perc_loss / num_batches\n",
|
| 424 |
+
" avg_val_sobel = val_sobel_loss / num_batches\n",
|
| 425 |
+
" avg_val_total = avg_val_l1 + avg_val_perc + avg_val_sobel\n",
|
| 426 |
+
"\n",
|
| 427 |
+
" # Set models back to training mode\n",
|
| 428 |
+
" generator.train()\n",
|
| 429 |
+
"\n",
|
| 430 |
+
" return ValidationResultsDict(\n",
|
| 431 |
+
" l1=avg_val_l1,\n",
|
| 432 |
+
" perceptual=avg_val_perc,\n",
|
| 433 |
+
" sobel=avg_val_sobel,\n",
|
| 434 |
+
" total=avg_val_total\n",
|
| 435 |
+
" )\n",
|
| 436 |
+
"\n",
|
| 437 |
+
"def create_mismatched_text_batch(text_ids, attention_mask):\n",
|
| 438 |
+
" \"\"\"Create a batch with mismatched text for wrong text conditioning\"\"\"\n",
|
| 439 |
+
" batch_size = text_ids.size(0)\n",
|
| 440 |
+
" indices = torch.randperm(batch_size)\n",
|
| 441 |
+
" return text_ids[indices], attention_mask[indices]\n",
|
| 442 |
+
"\n",
|
| 443 |
+
"def compute_discriminator_loss(\n",
|
| 444 |
+
" discriminator,\n",
|
| 445 |
+
" real_images,\n",
|
| 446 |
+
" fake_images,\n",
|
| 447 |
+
" text_ids,\n",
|
| 448 |
+
" attention_mask,\n",
|
| 449 |
+
" wrong_text_ids,\n",
|
| 450 |
+
" wrong_attention_mask,\n",
|
| 451 |
+
" real_labels,\n",
|
| 452 |
+
" fake_labels,\n",
|
| 453 |
+
" adv_criterion\n",
|
| 454 |
+
") -> tuple[torch.Tensor, DiscriminatorComponentsDict]:\n",
|
| 455 |
+
" \"\"\"Compute discriminator loss with the 4 components.\n",
|
| 456 |
+
" Returns the total loss and the 4 components.\"\"\"\n",
|
| 457 |
+
" # Real images with correct text\n",
|
| 458 |
+
" real_uncond, real_cond = discriminator(real_images, text_ids, attention_mask, return_both=True)\n",
|
| 459 |
+
" real_uncond_loss = adv_criterion(real_uncond, real_labels)\n",
|
| 460 |
+
" real_cond_loss = adv_criterion(real_cond, real_labels)\n",
|
| 461 |
+
"\n",
|
| 462 |
+
" # Real images with wrong text\n",
|
| 463 |
+
" _, real_cond_wrong = discriminator(real_images, wrong_text_ids, wrong_attention_mask, return_both=True)\n",
|
| 464 |
+
" real_cond_wrong_loss = adv_criterion(real_cond_wrong, fake_labels)\n",
|
| 465 |
+
"\n",
|
| 466 |
+
" # Fake images with wrong text\n",
|
| 467 |
+
" fake_uncond, _ = discriminator(fake_images.detach(), wrong_text_ids, wrong_attention_mask, return_both=True)\n",
|
| 468 |
+
" fake_uncond_loss = adv_criterion(fake_uncond, fake_labels)\n",
|
| 469 |
+
"\n",
|
| 470 |
+
" total_loss = (real_uncond_loss + real_cond_loss + real_cond_wrong_loss + fake_uncond_loss) / 4\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" components: DiscriminatorComponentsDict = {\n",
|
| 473 |
+
" 'real_uncond': real_uncond_loss.item(),\n",
|
| 474 |
+
" 'real_cond': real_cond_loss.item(),\n",
|
| 475 |
+
" 'real_cond_wrong': real_cond_wrong_loss.item(),\n",
|
| 476 |
+
" 'fake_uncond': fake_uncond_loss.item(),\n",
|
| 477 |
+
" }\n",
|
| 478 |
+
"\n",
|
| 479 |
+
" return total_loss, components\n",
|
| 480 |
+
"\n",
|
| 481 |
+
"def compute_generator_adversarial_loss(\n",
|
| 482 |
+
" discriminator,\n",
|
| 483 |
+
" fake_images,\n",
|
| 484 |
+
" text_ids,\n",
|
| 485 |
+
" attention_mask,\n",
|
| 486 |
+
" real_labels,\n",
|
| 487 |
+
" adv_criterion\n",
|
| 488 |
+
") -> torch.Tensor:\n",
|
| 489 |
+
" \"\"\"Compute generator adversarial loss for one discriminator\"\"\"\n",
|
| 490 |
+
" fake_uncond, fake_cond = discriminator(fake_images, text_ids, attention_mask, return_both=True)\n",
|
| 491 |
+
" uncond_loss = adv_criterion(fake_uncond, real_labels)\n",
|
| 492 |
+
" cond_loss = adv_criterion(fake_cond, real_labels)\n",
|
| 493 |
+
" return (uncond_loss + cond_loss) / 2\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"def compute_loss(fake_images, real_images, criterion, lmd):\n",
|
| 496 |
+
" \"\"\"Compute a reconstruction loss only if its lambda > 0\"\"\"\n",
|
| 497 |
+
" return criterion(fake_images, real_images) if lmd > 0 else torch.tensor(0.0, device=device)\n",
|
| 498 |
+
"\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"epoch = 0\n",
|
| 501 |
+
"noise_dim = 100\n"
|
| 502 |
+
]
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"cell_type": "raw",
|
| 506 |
+
"metadata": {
|
| 507 |
+
"id": "Oenm8AkasnNt",
|
| 508 |
+
"vscode": {
|
| 509 |
+
"languageId": "raw"
|
| 510 |
+
}
|
| 511 |
+
},
|
| 512 |
+
"source": [
|
| 513 |
+
"## 4. GAN Training Loop"
|
| 514 |
+
]
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"cell_type": "code",
|
| 518 |
+
"execution_count": null,
|
| 519 |
+
"metadata": {
|
| 520 |
+
"colab": {
|
| 521 |
+
"base_uri": "https://localhost:8080/",
|
| 522 |
+
"height": 1000
|
| 523 |
+
},
|
| 524 |
+
"id": "gmo0Mi6osnNt",
|
| 525 |
+
"outputId": "a4691684-def3-4d29-d00d-20ae40287c8c"
|
| 526 |
+
},
|
| 527 |
+
"outputs": [],
|
| 528 |
+
"source": [
|
| 529 |
+
"from IPython.display import clear_output\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"total_epochs = 150\n",
|
| 532 |
+
"display_interval = 1 # To show generation of training sample\n",
|
| 533 |
+
"save_interval = 15 # To save checkpoint\n",
|
| 534 |
+
"clear_interval = 22 # To clear cell output. If too high or not present, Kaggle page would crash\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"lambda_l1 = 1.0\n",
|
| 537 |
+
"lambda_adv = 1.0\n",
|
| 538 |
+
"lambda_perceptual = 0.0\n",
|
| 539 |
+
"lambda_sobel = 0.0\n",
|
| 540 |
+
"\n",
|
| 541 |
+
"real_label = 1.0\n",
|
| 542 |
+
"fake_label = 0.0\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"print(\"Starting training with dual discriminators...\")\n",
|
| 545 |
+
"\n",
|
| 546 |
+
"for epoch in range(epoch, total_epochs):\n",
|
| 547 |
+
" epoch_g_loss = 0.0\n",
|
| 548 |
+
" epoch_d_loss_64 = 0.0\n",
|
| 549 |
+
" epoch_d_loss_256 = 0.0\n",
|
| 550 |
+
" epoch_l1_loss = 0.0\n",
|
| 551 |
+
" epoch_perc_loss = 0.0\n",
|
| 552 |
+
" epoch_sobel_loss = 0.0\n",
|
| 553 |
+
"\n",
|
| 554 |
+
" # Track discriminator loss components\n",
|
| 555 |
+
" epoch_d256_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
|
| 556 |
+
" epoch_d64_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
|
| 557 |
+
"\n",
|
| 558 |
+
" progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{total_epochs}\")\n",
|
| 559 |
+
"\n",
|
| 560 |
+
" for i, batch in enumerate(progress_bar):\n",
|
| 561 |
+
" batch_size = batch['image'].size(0)\n",
|
| 562 |
+
"\n",
|
| 563 |
+
" # Move data to device\n",
|
| 564 |
+
" real_images = batch['image'].to(device)\n",
|
| 565 |
+
" text_ids = batch['text'].to(device)\n",
|
| 566 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
| 567 |
+
"\n",
|
| 568 |
+
" # Create labels and mismatched text for GAN training\n",
|
| 569 |
+
" real_labels = torch.full((batch_size, 1), real_label, device=device, dtype=torch.float)\n",
|
| 570 |
+
" fake_labels = torch.full((batch_size, 1), fake_label, device=device, dtype=torch.float)\n",
|
| 571 |
+
" wrong_text_ids, wrong_attention_mask = create_mismatched_text_batch(text_ids, attention_mask)\n",
|
| 572 |
+
"\n",
|
| 573 |
+
" # Generate fake images\n",
|
| 574 |
+
" fake_images_256, fake_images_64 = generator(text_ids, attention_mask)\n",
|
| 575 |
+
" real_images_64 = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)\n",
|
| 576 |
+
"\n",
|
| 577 |
+
" # Training both discriminators\n",
|
| 578 |
+
" optimizer_D_256.zero_grad()\n",
|
| 579 |
+
" optimizer_D_64.zero_grad()\n",
|
| 580 |
+
"\n",
|
| 581 |
+
" d_loss_256, d256_components = compute_discriminator_loss(\n",
|
| 582 |
+
" discriminator_256, real_images, fake_images_256,\n",
|
| 583 |
+
" text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
|
| 584 |
+
" real_labels, fake_labels, adv_criterion\n",
|
| 585 |
+
" )\n",
|
| 586 |
+
" d_loss_256.backward()\n",
|
| 587 |
+
"\n",
|
| 588 |
+
" d_loss_64, d64_components = compute_discriminator_loss(\n",
|
| 589 |
+
" discriminator_64, real_images_64, fake_images_64,\n",
|
| 590 |
+
" text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
|
| 591 |
+
" real_labels, fake_labels, adv_criterion\n",
|
| 592 |
+
" )\n",
|
| 593 |
+
" d_loss_64.backward()\n",
|
| 594 |
+
"\n",
|
| 595 |
+
" optimizer_D_256.step()\n",
|
| 596 |
+
" optimizer_D_64.step()\n",
|
| 597 |
+
"\n",
|
| 598 |
+
" # Training generator\n",
|
| 599 |
+
" optimizer_G.zero_grad()\n",
|
| 600 |
+
"\n",
|
| 601 |
+
" # Adversarial losses for both discriminators\n",
|
| 602 |
+
" g_adv_loss_256 = compute_generator_adversarial_loss(\n",
|
| 603 |
+
" discriminator_256, fake_images_256, text_ids, attention_mask, real_labels, adv_criterion\n",
|
| 604 |
+
" )\n",
|
| 605 |
+
" g_adv_loss_64 = compute_generator_adversarial_loss(\n",
|
| 606 |
+
" discriminator_64, fake_images_64, text_ids, attention_mask, real_labels, adv_criterion\n",
|
| 607 |
+
" )\n",
|
| 608 |
+
" adversarial_loss = (g_adv_loss_256 + g_adv_loss_64) / 2\n",
|
| 609 |
+
"\n",
|
| 610 |
+
" # Compute losses if their lambda is > 0\n",
|
| 611 |
+
" l1_loss = compute_loss(fake_images_256, real_images, l1_criterion, lambda_l1)\n",
|
| 612 |
+
" perc_loss = compute_loss(fake_images_256, real_images, perc_criterion, lambda_perceptual)\n",
|
| 613 |
+
" sobel_loss = compute_loss(fake_images_256, real_images, sobel_criterion, lambda_sobel)\n",
|
| 614 |
+
"\n",
|
| 615 |
+
" # Total generator loss\n",
|
| 616 |
+
" g_loss = (\n",
|
| 617 |
+
" lambda_adv * adversarial_loss +\n",
|
| 618 |
+
" lambda_l1 * l1_loss +\n",
|
| 619 |
+
" lambda_perceptual * perc_loss +\n",
|
| 620 |
+
" lambda_sobel * sobel_loss\n",
|
| 621 |
+
" )\n",
|
| 622 |
+
" g_loss.backward()\n",
|
| 623 |
+
" optimizer_G.step()\n",
|
| 624 |
+
"\n",
|
| 625 |
+
" # Update loss tracking\n",
|
| 626 |
+
" epoch_g_loss += g_loss.item()\n",
|
| 627 |
+
" epoch_d_loss_256 += d_loss_256.item()\n",
|
| 628 |
+
" epoch_d_loss_64 += d_loss_64.item()\n",
|
| 629 |
+
" epoch_l1_loss += l1_loss.item()\n",
|
| 630 |
+
" epoch_perc_loss += perc_loss.item()\n",
|
| 631 |
+
" epoch_sobel_loss += sobel_loss.item()\n",
|
| 632 |
+
"\n",
|
| 633 |
+
" # Update discriminator component tracking\n",
|
| 634 |
+
" for key in epoch_d256_components:\n",
|
| 635 |
+
" epoch_d256_components[key] += d256_components[key]\n",
|
| 636 |
+
" epoch_d64_components[key] += d64_components[key]\n",
|
| 637 |
+
"\n",
|
| 638 |
+
" # Update progress bar with detailed losses and loss components\n",
|
| 639 |
+
" progress_bar.set_postfix({\n",
|
| 640 |
+
" 'G': f'{g_loss.item():.3f}',\n",
|
| 641 |
+
" 'L1': f'{l1_loss.item():.3f}',\n",
|
| 642 |
+
" 'Adv': f'{adversarial_loss.item():.3f}',\n",
|
| 643 |
+
" 'D256': f'{d_loss_256.item():.3f}',\n",
|
| 644 |
+
" 'D256_ru': f'{d256_components[\"real_uncond\"]:.3f}',\n",
|
| 645 |
+
" 'D256_rc': f'{d256_components[\"real_cond\"]:.3f}',\n",
|
| 646 |
+
" 'D256_rcw': f'{d256_components[\"real_cond_wrong\"]:.3f}',\n",
|
| 647 |
+
" 'D256_fu': f'{d256_components[\"fake_uncond\"]:.3f}',\n",
|
| 648 |
+
" 'D64': f'{d_loss_64.item():.3f}',\n",
|
| 649 |
+
" 'D64_ru': f'{d64_components[\"real_uncond\"]:.3f}',\n",
|
| 650 |
+
" 'D64_rc': f'{d64_components[\"real_cond\"]:.3f}',\n",
|
| 651 |
+
" 'D64_rcw': f'{d64_components[\"real_cond_wrong\"]:.3f}',\n",
|
| 652 |
+
" 'D64_fu': f'{d64_components[\"fake_uncond\"]:.3f}',\n",
|
| 653 |
+
" })\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" # Calculate average losses for the epoch\n",
|
| 656 |
+
" avg_g_loss = epoch_g_loss / len(train_loader)\n",
|
| 657 |
+
" avg_d_loss_256 = epoch_d_loss_256 / len(train_loader)\n",
|
| 658 |
+
" avg_d_loss_64 = epoch_d_loss_64 / len(train_loader)\n",
|
| 659 |
+
" avg_l1_loss = epoch_l1_loss / len(train_loader)\n",
|
| 660 |
+
" avg_perc_loss = epoch_perc_loss / len(train_loader)\n",
|
| 661 |
+
" avg_sobel_loss = epoch_sobel_loss / len(train_loader)\n",
|
| 662 |
+
"\n",
|
| 663 |
+
" # Calculate average discriminator components for epoch\n",
|
| 664 |
+
" avg_d256_components = {key: val / len(train_loader) for key, val in epoch_d256_components.items()}\n",
|
| 665 |
+
" avg_d64_components = {key: val / len(train_loader) for key, val in epoch_d64_components.items()}\n",
|
| 666 |
+
"\n",
|
| 667 |
+
" # Store losses (combine discriminator losses)\n",
|
| 668 |
+
" losses['generator'].append(avg_g_loss)\n",
|
| 669 |
+
" losses['discriminator'].append((avg_d_loss_256 + avg_d_loss_64) / 2)\n",
|
| 670 |
+
" losses['l1'].append(avg_l1_loss)\n",
|
| 671 |
+
" losses['perceptual'].append(avg_perc_loss)\n",
|
| 672 |
+
" losses['sobel'].append(avg_sobel_loss)\n",
|
| 673 |
+
"\n",
|
| 674 |
+
" print(f\"Running validation for epoch {epoch+1}...\")\n",
|
| 675 |
+
" validation_results = validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion)\n",
|
| 676 |
+
"\n",
|
| 677 |
+
" # Store validation losses\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" for k, v in validation_results.items():\n",
|
| 680 |
+
" val_losses[k].append(v)\n",
|
| 681 |
+
"\n",
|
| 682 |
+
" if (epoch + 1) % clear_interval == 0:\n",
|
| 683 |
+
" clear_output(wait=True)\n",
|
| 684 |
+
"\n",
|
| 685 |
+
" print(f\"Epoch [{epoch+1}/{total_epochs}]\")\n",
|
| 686 |
+
" print(f\" Train - D_256: {avg_d_loss_256:.4f}, D_64: {avg_d_loss_64:.4f}, G_loss: {avg_g_loss:.4f}\")\n",
|
| 687 |
+
" print(f\" D_256 Components - RU: {avg_d256_components['real_uncond']:.4f}, RC: {avg_d256_components['real_cond']:.4f}, RCW: {avg_d256_components['real_cond_wrong']:.4f}, FU: {avg_d256_components['fake_uncond']:.4f}\")\n",
|
| 688 |
+
" print(f\" D_64 Components - RU: {avg_d64_components['real_uncond']:.4f}, RC: {avg_d64_components['real_cond']:.4f}, RCW: {avg_d64_components['real_cond_wrong']:.4f}, FU: {avg_d64_components['fake_uncond']:.4f}\")\n",
|
| 689 |
+
" print(f\" Train - L1: {avg_l1_loss:.4f}, Perceptual: {avg_perc_loss:.4f}, Sobel: {avg_sobel_loss:.4f}\")\n",
|
| 690 |
+
" print(f\" Val - L1: {validation_results['l1']:.4f}, Perceptual: {validation_results['perceptual']:.4f}, Sobel: {validation_results['sobel']:.4f}, Total: {validation_results['total']:.4f}\")\n",
|
| 691 |
+
" print(\" Legend: RU=RealUncond, RC=RealCond, RCW=RealCondWrong, FU=FakeUncond\")\n",
|
| 692 |
+
"\n",
|
| 693 |
+
" # Display generated images\n",
|
| 694 |
+
" if (epoch + 1) % display_interval == 0:\n",
|
| 695 |
+
" print(f\"\\nGenerating sample images at epoch {epoch+1}:\")\n",
|
| 696 |
+
" print(\"256x256 Training Images:\")\n",
|
| 697 |
+
" save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_256\", device, show_inline=True)\n",
|
| 698 |
+
" print(\"64x64 Training Images:\")\n",
|
| 699 |
+
" save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_64\", device, show_inline=True)\n",
|
| 700 |
+
"\n",
|
| 701 |
+
" # Save checkpoint and show visualizations\n",
|
| 702 |
+
" if (epoch + 1) % save_interval == 0:\n",
|
| 703 |
+
" checkpoint_path = f'models/checkpoint_epoch_{epoch+1}.pth'\n",
|
| 704 |
+
" all_losses = {'train': losses, 'val': val_losses}\n",
|
| 705 |
+
" checkpoint = {\n",
|
| 706 |
+
" 'epoch': epoch,\n",
|
| 707 |
+
" 'generator_state_dict': generator.state_dict(),\n",
|
| 708 |
+
" 'discriminator_256_state_dict': discriminator_256.state_dict(),\n",
|
| 709 |
+
" 'discriminator_64_state_dict': discriminator_64.state_dict(),\n",
|
| 710 |
+
" 'g_optimizer_state_dict': optimizer_G.state_dict(),\n",
|
| 711 |
+
" 'd_optimizer_state_dict': optimizer_D_256.state_dict(),\n",
|
| 712 |
+
" 'd_64_optimizer_state_dict': optimizer_D_64.state_dict(),\n",
|
| 713 |
+
" 'losses': all_losses\n",
|
| 714 |
+
" }\n",
|
| 715 |
+
" torch.save(checkpoint, checkpoint_path)\n",
|
| 716 |
+
" print(f\"Checkpoint saved to {checkpoint_path}\")\n",
|
| 717 |
+
"\n",
|
| 718 |
+
" print(\"256x256 Validation Images:\")\n",
|
| 719 |
+
" save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_256\", device, show_inline=True)\n",
|
| 720 |
+
" print(\"64x64 Validation Images:\")\n",
|
| 721 |
+
" save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_64\", device, show_inline=True)\n",
|
| 722 |
+
" save_attention_visualization(epoch+1, generator, tokenizer, fixed_train_batch, device, \"train\", show_inline=True)\n",
|
| 723 |
+
" save_attention_visualization(epoch+1, generator, tokenizer, fixed_val_batch, device, \"val\", show_inline=True)\n",
|
| 724 |
+
"\n",
|
| 725 |
+
"print(\"Training completed!\")\n"
|
| 726 |
+
]
|
| 727 |
+
},
|
| 728 |
+
{
|
| 729 |
+
"cell_type": "raw",
|
| 730 |
+
"metadata": {
|
| 731 |
+
"id": "rbv1Wz4csnNu",
|
| 732 |
+
"vscode": {
|
| 733 |
+
"languageId": "raw"
|
| 734 |
+
}
|
| 735 |
+
},
|
| 736 |
+
"source": [
|
| 737 |
+
"## 5. Training Analysis and Visualization\n"
|
| 738 |
+
]
|
| 739 |
+
},
|
| 740 |
+
{
|
| 741 |
+
"cell_type": "code",
|
| 742 |
+
"execution_count": null,
|
| 743 |
+
"metadata": {
|
| 744 |
+
"id": "l_90zE2CsnNu"
|
| 745 |
+
},
|
| 746 |
+
"outputs": [],
|
| 747 |
+
"source": [
|
| 748 |
+
"from plots import save_plot_losses, save_plot_non_gan_losses\n",
|
| 749 |
+
"\n",
|
| 750 |
+
"\n",
|
| 751 |
+
"save_plot_losses(\n",
|
| 752 |
+
" losses_g=losses['generator'],\n",
|
| 753 |
+
" losses_d=losses['discriminator'],\n",
|
| 754 |
+
" output_dir=\"training_output\",\n",
|
| 755 |
+
" show_inline=True\n",
|
| 756 |
+
")\n",
|
| 757 |
+
"\n",
|
| 758 |
+
"# Plot training vs validation losses for non-GAN components (so except \"generator\" and \"discriminator\" from losses)\n",
|
| 759 |
+
"# Convert to list of dicts format expected by save_plot_non_gan_losses\n",
|
| 760 |
+
"train_losses_history = []\n",
|
| 761 |
+
"val_losses_history = []\n",
|
| 762 |
+
"\n",
|
| 763 |
+
"for i in range(len(losses['l1'])):\n",
|
| 764 |
+
" train_losses_history.append({\n",
|
| 765 |
+
" 'l1': losses['l1'][i],\n",
|
| 766 |
+
" 'perceptual': losses['perceptual'][i],\n",
|
| 767 |
+
" 'sobel': losses['sobel'][i],\n",
|
| 768 |
+
" 'total': losses['l1'][i] + losses['perceptual'][i] + losses['sobel'][i]\n",
|
| 769 |
+
" })\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"for i in range(len(val_losses['l1'])):\n",
|
| 772 |
+
" val_losses_history.append({\n",
|
| 773 |
+
" 'l1': val_losses['l1'][i],\n",
|
| 774 |
+
" 'perceptual': val_losses['perceptual'][i],\n",
|
| 775 |
+
" 'sobel': val_losses['sobel'][i],\n",
|
| 776 |
+
" 'total': val_losses['total'][i]\n",
|
| 777 |
+
" })\n",
|
| 778 |
+
"\n",
|
| 779 |
+
"save_plot_non_gan_losses(\n",
|
| 780 |
+
" train_losses_history=train_losses_history,\n",
|
| 781 |
+
" val_losses_history=val_losses_history,\n",
|
| 782 |
+
" output_dir=\"training_output\",\n",
|
| 783 |
+
" show_inline=True\n",
|
| 784 |
+
")\n",
|
| 785 |
+
"\n",
|
| 786 |
+
"# Print final statistics\n",
|
| 787 |
+
"print(f\"Final Train - Generator Loss: {losses['generator'][-1]:.4f}\")\n",
|
| 788 |
+
"print(f\"Final Train - Discriminator Loss: {losses['discriminator'][-1]:.4f}\")\n",
|
| 789 |
+
"print(f\"Final Train - L1 Loss: {losses['l1'][-1]:.4f}\")\n",
|
| 790 |
+
"print(f\"Final Train - Perceptual Loss: {losses['perceptual'][-1]:.4f}\")\n",
|
| 791 |
+
"print(f\"Final Train - Sobel Loss: {losses['sobel'][-1]:.4f}\")\n",
|
| 792 |
+
"\n",
|
| 793 |
+
"print(f\"Final Val - L1 Loss: {val_losses['l1'][-1]:.4f}\")\n",
|
| 794 |
+
"print(f\"Final Val - Perceptual Loss: {val_losses['perceptual'][-1]:.4f}\")\n",
|
| 795 |
+
"print(f\"Final Val - Sobel Loss: {val_losses['sobel'][-1]:.4f}\")\n",
|
| 796 |
+
"print(f\"Final Val - Total Loss: {val_losses['total'][-1]:.4f}\")\n"
|
| 797 |
+
]
|
| 798 |
+
},
|
| 799 |
+
{
|
| 800 |
+
"cell_type": "code",
|
| 801 |
+
"execution_count": null,
|
| 802 |
+
"metadata": {
|
| 803 |
+
"id": "Io7I7RTqsnNu"
|
| 804 |
+
},
|
| 805 |
+
"outputs": [],
|
| 806 |
+
"source": [
|
| 807 |
+
"# Generate a grid of final results\n",
|
| 808 |
+
"print(\"Final Results - Generated Pokemon Sprites (256x256):\")\n",
|
| 809 |
+
"batch = next(iter(train_loader))\n",
|
| 810 |
+
"save_comparison_grid(0, generator, batch, \"final_256\", device, show_inline=True)\n",
|
| 811 |
+
"\n",
|
| 812 |
+
"print(\"Final Results - Generated Pokemon Sprites (64x64):\")\n",
|
| 813 |
+
"save_comparison_grid(0, generator, batch, \"final_64\", device, show_inline=True)\n"
|
| 814 |
+
]
|
| 815 |
+
},
|
| 816 |
+
{
|
| 817 |
+
"cell_type": "raw",
|
| 818 |
+
"metadata": {
|
| 819 |
+
"id": "3a_jxGvCsnNu",
|
| 820 |
+
"vscode": {
|
| 821 |
+
"languageId": "raw"
|
| 822 |
+
}
|
| 823 |
+
},
|
| 824 |
+
"source": [
|
| 825 |
+
"## 7. Model Analysis\n"
|
| 826 |
+
]
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
"cell_type": "code",
|
| 830 |
+
"execution_count": null,
|
| 831 |
+
"metadata": {},
|
| 832 |
+
"outputs": [],
|
| 833 |
+
"source": [
|
| 834 |
+
"def count_parameters(model):\n",
|
| 835 |
+
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 836 |
+
"\n",
|
| 837 |
+
"print(f\"Generator parameters: {count_parameters(generator):,}\")\n",
|
| 838 |
+
"print(f\"Discriminator (256) parameters: {count_parameters(discriminator_256):,}\")\n",
|
| 839 |
+
"print(f\"Discriminator (64) parameters: {count_parameters(discriminator_64):,}\")\n"
|
| 840 |
+
]
|
| 841 |
+
}
|
| 842 |
+
],
|
| 843 |
+
"metadata": {
|
| 844 |
+
"accelerator": "GPU",
|
| 845 |
+
"colab": {
|
| 846 |
+
"gpuType": "T4",
|
| 847 |
+
"provenance": []
|
| 848 |
+
},
|
| 849 |
+
"kernelspec": {
|
| 850 |
+
"display_name": "Python 3 (ipykernel)",
|
| 851 |
+
"language": "python",
|
| 852 |
+
"name": "python3"
|
| 853 |
+
},
|
| 854 |
+
"language_info": {
|
| 855 |
+
"codemirror_mode": {
|
| 856 |
+
"name": "ipython",
|
| 857 |
+
"version": 3
|
| 858 |
+
},
|
| 859 |
+
"file_extension": ".py",
|
| 860 |
+
"mimetype": "text/x-python",
|
| 861 |
+
"name": "python",
|
| 862 |
+
"nbconvert_exporter": "python",
|
| 863 |
+
"pygments_lexer": "ipython3",
|
| 864 |
+
"version": "3.12.11"
|
| 865 |
+
},
|
| 866 |
+
"widgets": {
|
| 867 |
+
"application/vnd.jupyter.widget-state+json": {
|
| 868 |
+
"128f4312bcdc4166b9e24d8cdd34184d": {
|
| 869 |
+
"model_module": "@jupyter-widgets/base",
|
| 870 |
+
"model_module_version": "1.2.0",
|
| 871 |
+
"model_name": "LayoutModel",
|
| 872 |
+
"state": {
|
| 873 |
+
"_model_module": "@jupyter-widgets/base",
|
| 874 |
+
"_model_module_version": "1.2.0",
|
| 875 |
+
"_model_name": "LayoutModel",
|
| 876 |
+
"_view_count": null,
|
| 877 |
+
"_view_module": "@jupyter-widgets/base",
|
| 878 |
+
"_view_module_version": "1.2.0",
|
| 879 |
+
"_view_name": "LayoutView",
|
| 880 |
+
"align_content": null,
|
| 881 |
+
"align_items": null,
|
| 882 |
+
"align_self": null,
|
| 883 |
+
"border": null,
|
| 884 |
+
"bottom": null,
|
| 885 |
+
"display": null,
|
| 886 |
+
"flex": null,
|
| 887 |
+
"flex_flow": null,
|
| 888 |
+
"grid_area": null,
|
| 889 |
+
"grid_auto_columns": null,
|
| 890 |
+
"grid_auto_flow": null,
|
| 891 |
+
"grid_auto_rows": null,
|
| 892 |
+
"grid_column": null,
|
| 893 |
+
"grid_gap": null,
|
| 894 |
+
"grid_row": null,
|
| 895 |
+
"grid_template_areas": null,
|
| 896 |
+
"grid_template_columns": null,
|
| 897 |
+
"grid_template_rows": null,
|
| 898 |
+
"height": null,
|
| 899 |
+
"justify_content": null,
|
| 900 |
+
"justify_items": null,
|
| 901 |
+
"left": null,
|
| 902 |
+
"margin": null,
|
| 903 |
+
"max_height": null,
|
| 904 |
+
"max_width": null,
|
| 905 |
+
"min_height": null,
|
| 906 |
+
"min_width": null,
|
| 907 |
+
"object_fit": null,
|
| 908 |
+
"object_position": null,
|
| 909 |
+
"order": null,
|
| 910 |
+
"overflow": null,
|
| 911 |
+
"overflow_x": null,
|
| 912 |
+
"overflow_y": null,
|
| 913 |
+
"padding": null,
|
| 914 |
+
"right": null,
|
| 915 |
+
"top": null,
|
| 916 |
+
"visibility": null,
|
| 917 |
+
"width": null
|
| 918 |
+
}
|
| 919 |
+
},
|
| 920 |
+
"1b65d6c8540e4f458886d5e7075ab30a": {
|
| 921 |
+
"model_module": "@jupyter-widgets/controls",
|
| 922 |
+
"model_module_version": "1.5.0",
|
| 923 |
+
"model_name": "DescriptionStyleModel",
|
| 924 |
+
"state": {
|
| 925 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 926 |
+
"_model_module_version": "1.5.0",
|
| 927 |
+
"_model_name": "DescriptionStyleModel",
|
| 928 |
+
"_view_count": null,
|
| 929 |
+
"_view_module": "@jupyter-widgets/base",
|
| 930 |
+
"_view_module_version": "1.2.0",
|
| 931 |
+
"_view_name": "StyleView",
|
| 932 |
+
"description_width": ""
|
| 933 |
+
}
|
| 934 |
+
},
|
| 935 |
+
"2fe9614fe5984fa6b887d1e1b3e18b04": {
|
| 936 |
+
"model_module": "@jupyter-widgets/controls",
|
| 937 |
+
"model_module_version": "1.5.0",
|
| 938 |
+
"model_name": "ProgressStyleModel",
|
| 939 |
+
"state": {
|
| 940 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 941 |
+
"_model_module_version": "1.5.0",
|
| 942 |
+
"_model_name": "ProgressStyleModel",
|
| 943 |
+
"_view_count": null,
|
| 944 |
+
"_view_module": "@jupyter-widgets/base",
|
| 945 |
+
"_view_module_version": "1.2.0",
|
| 946 |
+
"_view_name": "StyleView",
|
| 947 |
+
"bar_color": null,
|
| 948 |
+
"description_width": ""
|
| 949 |
+
}
|
| 950 |
+
},
|
| 951 |
+
"304d50e74ad744cdb3a7cc88739cb923": {
|
| 952 |
+
"model_module": "@jupyter-widgets/base",
|
| 953 |
+
"model_module_version": "1.2.0",
|
| 954 |
+
"model_name": "LayoutModel",
|
| 955 |
+
"state": {
|
| 956 |
+
"_model_module": "@jupyter-widgets/base",
|
| 957 |
+
"_model_module_version": "1.2.0",
|
| 958 |
+
"_model_name": "LayoutModel",
|
| 959 |
+
"_view_count": null,
|
| 960 |
+
"_view_module": "@jupyter-widgets/base",
|
| 961 |
+
"_view_module_version": "1.2.0",
|
| 962 |
+
"_view_name": "LayoutView",
|
| 963 |
+
"align_content": null,
|
| 964 |
+
"align_items": null,
|
| 965 |
+
"align_self": null,
|
| 966 |
+
"border": null,
|
| 967 |
+
"bottom": null,
|
| 968 |
+
"display": null,
|
| 969 |
+
"flex": null,
|
| 970 |
+
"flex_flow": null,
|
| 971 |
+
"grid_area": null,
|
| 972 |
+
"grid_auto_columns": null,
|
| 973 |
+
"grid_auto_flow": null,
|
| 974 |
+
"grid_auto_rows": null,
|
| 975 |
+
"grid_column": null,
|
| 976 |
+
"grid_gap": null,
|
| 977 |
+
"grid_row": null,
|
| 978 |
+
"grid_template_areas": null,
|
| 979 |
+
"grid_template_columns": null,
|
| 980 |
+
"grid_template_rows": null,
|
| 981 |
+
"height": null,
|
| 982 |
+
"justify_content": null,
|
| 983 |
+
"justify_items": null,
|
| 984 |
+
"left": null,
|
| 985 |
+
"margin": null,
|
| 986 |
+
"max_height": null,
|
| 987 |
+
"max_width": null,
|
| 988 |
+
"min_height": null,
|
| 989 |
+
"min_width": null,
|
| 990 |
+
"object_fit": null,
|
| 991 |
+
"object_position": null,
|
| 992 |
+
"order": null,
|
| 993 |
+
"overflow": null,
|
| 994 |
+
"overflow_x": null,
|
| 995 |
+
"overflow_y": null,
|
| 996 |
+
"padding": null,
|
| 997 |
+
"right": null,
|
| 998 |
+
"top": null,
|
| 999 |
+
"visibility": null,
|
| 1000 |
+
"width": null
|
| 1001 |
+
}
|
| 1002 |
+
},
|
| 1003 |
+
"370e5663868f411697bfb24f4e3efa09": {
|
| 1004 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1005 |
+
"model_module_version": "1.5.0",
|
| 1006 |
+
"model_name": "HTMLModel",
|
| 1007 |
+
"state": {
|
| 1008 |
+
"_dom_classes": [],
|
| 1009 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1010 |
+
"_model_module_version": "1.5.0",
|
| 1011 |
+
"_model_name": "HTMLModel",
|
| 1012 |
+
"_view_count": null,
|
| 1013 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1014 |
+
"_view_module_version": "1.5.0",
|
| 1015 |
+
"_view_name": "HTMLView",
|
| 1016 |
+
"description": "",
|
| 1017 |
+
"description_tooltip": null,
|
| 1018 |
+
"layout": "IPY_MODEL_7e89bc79516f405e9684eacdce7b4551",
|
| 1019 |
+
"placeholder": "",
|
| 1020 |
+
"style": "IPY_MODEL_c917f3a000fb44338e4afbeabeaab55f",
|
| 1021 |
+
"value": " 232k/? [00:00<00:00, 12.0MB/s]"
|
| 1022 |
+
}
|
| 1023 |
+
},
|
| 1024 |
+
"3a338ac4d2944030a07843d8ea24e9fd": {
|
| 1025 |
+
"model_module": "@jupyter-widgets/base",
|
| 1026 |
+
"model_module_version": "1.2.0",
|
| 1027 |
+
"model_name": "LayoutModel",
|
| 1028 |
+
"state": {
|
| 1029 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1030 |
+
"_model_module_version": "1.2.0",
|
| 1031 |
+
"_model_name": "LayoutModel",
|
| 1032 |
+
"_view_count": null,
|
| 1033 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1034 |
+
"_view_module_version": "1.2.0",
|
| 1035 |
+
"_view_name": "LayoutView",
|
| 1036 |
+
"align_content": null,
|
| 1037 |
+
"align_items": null,
|
| 1038 |
+
"align_self": null,
|
| 1039 |
+
"border": null,
|
| 1040 |
+
"bottom": null,
|
| 1041 |
+
"display": null,
|
| 1042 |
+
"flex": null,
|
| 1043 |
+
"flex_flow": null,
|
| 1044 |
+
"grid_area": null,
|
| 1045 |
+
"grid_auto_columns": null,
|
| 1046 |
+
"grid_auto_flow": null,
|
| 1047 |
+
"grid_auto_rows": null,
|
| 1048 |
+
"grid_column": null,
|
| 1049 |
+
"grid_gap": null,
|
| 1050 |
+
"grid_row": null,
|
| 1051 |
+
"grid_template_areas": null,
|
| 1052 |
+
"grid_template_columns": null,
|
| 1053 |
+
"grid_template_rows": null,
|
| 1054 |
+
"height": null,
|
| 1055 |
+
"justify_content": null,
|
| 1056 |
+
"justify_items": null,
|
| 1057 |
+
"left": null,
|
| 1058 |
+
"margin": null,
|
| 1059 |
+
"max_height": null,
|
| 1060 |
+
"max_width": null,
|
| 1061 |
+
"min_height": null,
|
| 1062 |
+
"min_width": null,
|
| 1063 |
+
"object_fit": null,
|
| 1064 |
+
"object_position": null,
|
| 1065 |
+
"order": null,
|
| 1066 |
+
"overflow": null,
|
| 1067 |
+
"overflow_x": null,
|
| 1068 |
+
"overflow_y": null,
|
| 1069 |
+
"padding": null,
|
| 1070 |
+
"right": null,
|
| 1071 |
+
"top": null,
|
| 1072 |
+
"visibility": null,
|
| 1073 |
+
"width": null
|
| 1074 |
+
}
|
| 1075 |
+
},
|
| 1076 |
+
"439eba0eb4184c0ab83f65fc26bbe388": {
|
| 1077 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1078 |
+
"model_module_version": "1.5.0",
|
| 1079 |
+
"model_name": "HTMLModel",
|
| 1080 |
+
"state": {
|
| 1081 |
+
"_dom_classes": [],
|
| 1082 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1083 |
+
"_model_module_version": "1.5.0",
|
| 1084 |
+
"_model_name": "HTMLModel",
|
| 1085 |
+
"_view_count": null,
|
| 1086 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1087 |
+
"_view_module_version": "1.5.0",
|
| 1088 |
+
"_view_name": "HTMLView",
|
| 1089 |
+
"description": "",
|
| 1090 |
+
"description_tooltip": null,
|
| 1091 |
+
"layout": "IPY_MODEL_304d50e74ad744cdb3a7cc88739cb923",
|
| 1092 |
+
"placeholder": "",
|
| 1093 |
+
"style": "IPY_MODEL_bfcc6d01c9ff4db698afa4318e7c91ac",
|
| 1094 |
+
"value": "model.safetensors: 100%"
|
| 1095 |
+
}
|
| 1096 |
+
},
|
| 1097 |
+
"4545ff199b874d3680a83918513e1d4b": {
|
| 1098 |
+
"model_module": "@jupyter-widgets/base",
|
| 1099 |
+
"model_module_version": "1.2.0",
|
| 1100 |
+
"model_name": "LayoutModel",
|
| 1101 |
+
"state": {
|
| 1102 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1103 |
+
"_model_module_version": "1.2.0",
|
| 1104 |
+
"_model_name": "LayoutModel",
|
| 1105 |
+
"_view_count": null,
|
| 1106 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1107 |
+
"_view_module_version": "1.2.0",
|
| 1108 |
+
"_view_name": "LayoutView",
|
| 1109 |
+
"align_content": null,
|
| 1110 |
+
"align_items": null,
|
| 1111 |
+
"align_self": null,
|
| 1112 |
+
"border": null,
|
| 1113 |
+
"bottom": null,
|
| 1114 |
+
"display": null,
|
| 1115 |
+
"flex": null,
|
| 1116 |
+
"flex_flow": null,
|
| 1117 |
+
"grid_area": null,
|
| 1118 |
+
"grid_auto_columns": null,
|
| 1119 |
+
"grid_auto_flow": null,
|
| 1120 |
+
"grid_auto_rows": null,
|
| 1121 |
+
"grid_column": null,
|
| 1122 |
+
"grid_gap": null,
|
| 1123 |
+
"grid_row": null,
|
| 1124 |
+
"grid_template_areas": null,
|
| 1125 |
+
"grid_template_columns": null,
|
| 1126 |
+
"grid_template_rows": null,
|
| 1127 |
+
"height": null,
|
| 1128 |
+
"justify_content": null,
|
| 1129 |
+
"justify_items": null,
|
| 1130 |
+
"left": null,
|
| 1131 |
+
"margin": null,
|
| 1132 |
+
"max_height": null,
|
| 1133 |
+
"max_width": null,
|
| 1134 |
+
"min_height": null,
|
| 1135 |
+
"min_width": null,
|
| 1136 |
+
"object_fit": null,
|
| 1137 |
+
"object_position": null,
|
| 1138 |
+
"order": null,
|
| 1139 |
+
"overflow": null,
|
| 1140 |
+
"overflow_x": null,
|
| 1141 |
+
"overflow_y": null,
|
| 1142 |
+
"padding": null,
|
| 1143 |
+
"right": null,
|
| 1144 |
+
"top": null,
|
| 1145 |
+
"visibility": null,
|
| 1146 |
+
"width": null
|
| 1147 |
+
}
|
| 1148 |
+
},
|
| 1149 |
+
"4795a78a75dc439a8da7df58bf738940": {
|
| 1150 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1151 |
+
"model_module_version": "1.5.0",
|
| 1152 |
+
"model_name": "DescriptionStyleModel",
|
| 1153 |
+
"state": {
|
| 1154 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1155 |
+
"_model_module_version": "1.5.0",
|
| 1156 |
+
"_model_name": "DescriptionStyleModel",
|
| 1157 |
+
"_view_count": null,
|
| 1158 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1159 |
+
"_view_module_version": "1.2.0",
|
| 1160 |
+
"_view_name": "StyleView",
|
| 1161 |
+
"description_width": ""
|
| 1162 |
+
}
|
| 1163 |
+
},
|
| 1164 |
+
"4c22e1b396f342ffb90c1b50a0051862": {
|
| 1165 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1166 |
+
"model_module_version": "1.5.0",
|
| 1167 |
+
"model_name": "FloatProgressModel",
|
| 1168 |
+
"state": {
|
| 1169 |
+
"_dom_classes": [],
|
| 1170 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1171 |
+
"_model_module_version": "1.5.0",
|
| 1172 |
+
"_model_name": "FloatProgressModel",
|
| 1173 |
+
"_view_count": null,
|
| 1174 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1175 |
+
"_view_module_version": "1.5.0",
|
| 1176 |
+
"_view_name": "ProgressView",
|
| 1177 |
+
"bar_style": "success",
|
| 1178 |
+
"description": "",
|
| 1179 |
+
"description_tooltip": null,
|
| 1180 |
+
"layout": "IPY_MODEL_a5a9f8607fdd4f9cad7519eca573f3dc",
|
| 1181 |
+
"max": 1,
|
| 1182 |
+
"min": 0,
|
| 1183 |
+
"orientation": "horizontal",
|
| 1184 |
+
"style": "IPY_MODEL_926149594f94457295c60b4fad9cbac7",
|
| 1185 |
+
"value": 1
|
| 1186 |
+
}
|
| 1187 |
+
},
|
| 1188 |
+
"4e32e76c44fb449c8cb767abeb17868a": {
|
| 1189 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1190 |
+
"model_module_version": "1.5.0",
|
| 1191 |
+
"model_name": "DescriptionStyleModel",
|
| 1192 |
+
"state": {
|
| 1193 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1194 |
+
"_model_module_version": "1.5.0",
|
| 1195 |
+
"_model_name": "DescriptionStyleModel",
|
| 1196 |
+
"_view_count": null,
|
| 1197 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1198 |
+
"_view_module_version": "1.2.0",
|
| 1199 |
+
"_view_name": "StyleView",
|
| 1200 |
+
"description_width": ""
|
| 1201 |
+
}
|
| 1202 |
+
},
|
| 1203 |
+
"57e526d188b9414dabb3b1c895373864": {
|
| 1204 |
+
"model_module": "@jupyter-widgets/base",
|
| 1205 |
+
"model_module_version": "1.2.0",
|
| 1206 |
+
"model_name": "LayoutModel",
|
| 1207 |
+
"state": {
|
| 1208 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1209 |
+
"_model_module_version": "1.2.0",
|
| 1210 |
+
"_model_name": "LayoutModel",
|
| 1211 |
+
"_view_count": null,
|
| 1212 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1213 |
+
"_view_module_version": "1.2.0",
|
| 1214 |
+
"_view_name": "LayoutView",
|
| 1215 |
+
"align_content": null,
|
| 1216 |
+
"align_items": null,
|
| 1217 |
+
"align_self": null,
|
| 1218 |
+
"border": null,
|
| 1219 |
+
"bottom": null,
|
| 1220 |
+
"display": null,
|
| 1221 |
+
"flex": null,
|
| 1222 |
+
"flex_flow": null,
|
| 1223 |
+
"grid_area": null,
|
| 1224 |
+
"grid_auto_columns": null,
|
| 1225 |
+
"grid_auto_flow": null,
|
| 1226 |
+
"grid_auto_rows": null,
|
| 1227 |
+
"grid_column": null,
|
| 1228 |
+
"grid_gap": null,
|
| 1229 |
+
"grid_row": null,
|
| 1230 |
+
"grid_template_areas": null,
|
| 1231 |
+
"grid_template_columns": null,
|
| 1232 |
+
"grid_template_rows": null,
|
| 1233 |
+
"height": null,
|
| 1234 |
+
"justify_content": null,
|
| 1235 |
+
"justify_items": null,
|
| 1236 |
+
"left": null,
|
| 1237 |
+
"margin": null,
|
| 1238 |
+
"max_height": null,
|
| 1239 |
+
"max_width": null,
|
| 1240 |
+
"min_height": null,
|
| 1241 |
+
"min_width": null,
|
| 1242 |
+
"object_fit": null,
|
| 1243 |
+
"object_position": null,
|
| 1244 |
+
"order": null,
|
| 1245 |
+
"overflow": null,
|
| 1246 |
+
"overflow_x": null,
|
| 1247 |
+
"overflow_y": null,
|
| 1248 |
+
"padding": null,
|
| 1249 |
+
"right": null,
|
| 1250 |
+
"top": null,
|
| 1251 |
+
"visibility": null,
|
| 1252 |
+
"width": null
|
| 1253 |
+
}
|
| 1254 |
+
},
|
| 1255 |
+
"5837f2c4668646c0a6db2407aebb46e3": {
|
| 1256 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1257 |
+
"model_module_version": "1.5.0",
|
| 1258 |
+
"model_name": "HTMLModel",
|
| 1259 |
+
"state": {
|
| 1260 |
+
"_dom_classes": [],
|
| 1261 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1262 |
+
"_model_module_version": "1.5.0",
|
| 1263 |
+
"_model_name": "HTMLModel",
|
| 1264 |
+
"_view_count": null,
|
| 1265 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1266 |
+
"_view_module_version": "1.5.0",
|
| 1267 |
+
"_view_name": "HTMLView",
|
| 1268 |
+
"description": "",
|
| 1269 |
+
"description_tooltip": null,
|
| 1270 |
+
"layout": "IPY_MODEL_64277772cc30408e8ea29f0e268c8880",
|
| 1271 |
+
"placeholder": "",
|
| 1272 |
+
"style": "IPY_MODEL_5b0d55ea20714104818097bd7d1f509a",
|
| 1273 |
+
"value": " 45.1M/45.1M [00:00<00:00, 112MB/s]"
|
| 1274 |
+
}
|
| 1275 |
+
},
|
| 1276 |
+
"58ab975eaba2485cb0945482c26ecf3d": {
|
| 1277 |
+
"model_module": "@jupyter-widgets/base",
|
| 1278 |
+
"model_module_version": "1.2.0",
|
| 1279 |
+
"model_name": "LayoutModel",
|
| 1280 |
+
"state": {
|
| 1281 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1282 |
+
"_model_module_version": "1.2.0",
|
| 1283 |
+
"_model_name": "LayoutModel",
|
| 1284 |
+
"_view_count": null,
|
| 1285 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1286 |
+
"_view_module_version": "1.2.0",
|
| 1287 |
+
"_view_name": "LayoutView",
|
| 1288 |
+
"align_content": null,
|
| 1289 |
+
"align_items": null,
|
| 1290 |
+
"align_self": null,
|
| 1291 |
+
"border": null,
|
| 1292 |
+
"bottom": null,
|
| 1293 |
+
"display": null,
|
| 1294 |
+
"flex": null,
|
| 1295 |
+
"flex_flow": null,
|
| 1296 |
+
"grid_area": null,
|
| 1297 |
+
"grid_auto_columns": null,
|
| 1298 |
+
"grid_auto_flow": null,
|
| 1299 |
+
"grid_auto_rows": null,
|
| 1300 |
+
"grid_column": null,
|
| 1301 |
+
"grid_gap": null,
|
| 1302 |
+
"grid_row": null,
|
| 1303 |
+
"grid_template_areas": null,
|
| 1304 |
+
"grid_template_columns": null,
|
| 1305 |
+
"grid_template_rows": null,
|
| 1306 |
+
"height": null,
|
| 1307 |
+
"justify_content": null,
|
| 1308 |
+
"justify_items": null,
|
| 1309 |
+
"left": null,
|
| 1310 |
+
"margin": null,
|
| 1311 |
+
"max_height": null,
|
| 1312 |
+
"max_width": null,
|
| 1313 |
+
"min_height": null,
|
| 1314 |
+
"min_width": null,
|
| 1315 |
+
"object_fit": null,
|
| 1316 |
+
"object_position": null,
|
| 1317 |
+
"order": null,
|
| 1318 |
+
"overflow": null,
|
| 1319 |
+
"overflow_x": null,
|
| 1320 |
+
"overflow_y": null,
|
| 1321 |
+
"padding": null,
|
| 1322 |
+
"right": null,
|
| 1323 |
+
"top": null,
|
| 1324 |
+
"visibility": null,
|
| 1325 |
+
"width": null
|
| 1326 |
+
}
|
| 1327 |
+
},
|
| 1328 |
+
"5b0d55ea20714104818097bd7d1f509a": {
|
| 1329 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1330 |
+
"model_module_version": "1.5.0",
|
| 1331 |
+
"model_name": "DescriptionStyleModel",
|
| 1332 |
+
"state": {
|
| 1333 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1334 |
+
"_model_module_version": "1.5.0",
|
| 1335 |
+
"_model_name": "DescriptionStyleModel",
|
| 1336 |
+
"_view_count": null,
|
| 1337 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1338 |
+
"_view_module_version": "1.2.0",
|
| 1339 |
+
"_view_name": "StyleView",
|
| 1340 |
+
"description_width": ""
|
| 1341 |
+
}
|
| 1342 |
+
},
|
| 1343 |
+
"5ba39d9d997a45ca848e3e2ffd0e7307": {
|
| 1344 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1345 |
+
"model_module_version": "1.5.0",
|
| 1346 |
+
"model_name": "HTMLModel",
|
| 1347 |
+
"state": {
|
| 1348 |
+
"_dom_classes": [],
|
| 1349 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1350 |
+
"_model_module_version": "1.5.0",
|
| 1351 |
+
"_model_name": "HTMLModel",
|
| 1352 |
+
"_view_count": null,
|
| 1353 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1354 |
+
"_view_module_version": "1.5.0",
|
| 1355 |
+
"_view_name": "HTMLView",
|
| 1356 |
+
"description": "",
|
| 1357 |
+
"description_tooltip": null,
|
| 1358 |
+
"layout": "IPY_MODEL_128f4312bcdc4166b9e24d8cdd34184d",
|
| 1359 |
+
"placeholder": "",
|
| 1360 |
+
"style": "IPY_MODEL_1b65d6c8540e4f458886d5e7075ab30a",
|
| 1361 |
+
"value": "vocab.txt: "
|
| 1362 |
+
}
|
| 1363 |
+
},
|
| 1364 |
+
"5c3cb981f324446eae642f7c23a539f0": {
|
| 1365 |
+
"model_module": "@jupyter-widgets/base",
|
| 1366 |
+
"model_module_version": "1.2.0",
|
| 1367 |
+
"model_name": "LayoutModel",
|
| 1368 |
+
"state": {
|
| 1369 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1370 |
+
"_model_module_version": "1.2.0",
|
| 1371 |
+
"_model_name": "LayoutModel",
|
| 1372 |
+
"_view_count": null,
|
| 1373 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1374 |
+
"_view_module_version": "1.2.0",
|
| 1375 |
+
"_view_name": "LayoutView",
|
| 1376 |
+
"align_content": null,
|
| 1377 |
+
"align_items": null,
|
| 1378 |
+
"align_self": null,
|
| 1379 |
+
"border": null,
|
| 1380 |
+
"bottom": null,
|
| 1381 |
+
"display": null,
|
| 1382 |
+
"flex": null,
|
| 1383 |
+
"flex_flow": null,
|
| 1384 |
+
"grid_area": null,
|
| 1385 |
+
"grid_auto_columns": null,
|
| 1386 |
+
"grid_auto_flow": null,
|
| 1387 |
+
"grid_auto_rows": null,
|
| 1388 |
+
"grid_column": null,
|
| 1389 |
+
"grid_gap": null,
|
| 1390 |
+
"grid_row": null,
|
| 1391 |
+
"grid_template_areas": null,
|
| 1392 |
+
"grid_template_columns": null,
|
| 1393 |
+
"grid_template_rows": null,
|
| 1394 |
+
"height": null,
|
| 1395 |
+
"justify_content": null,
|
| 1396 |
+
"justify_items": null,
|
| 1397 |
+
"left": null,
|
| 1398 |
+
"margin": null,
|
| 1399 |
+
"max_height": null,
|
| 1400 |
+
"max_width": null,
|
| 1401 |
+
"min_height": null,
|
| 1402 |
+
"min_width": null,
|
| 1403 |
+
"object_fit": null,
|
| 1404 |
+
"object_position": null,
|
| 1405 |
+
"order": null,
|
| 1406 |
+
"overflow": null,
|
| 1407 |
+
"overflow_x": null,
|
| 1408 |
+
"overflow_y": null,
|
| 1409 |
+
"padding": null,
|
| 1410 |
+
"right": null,
|
| 1411 |
+
"top": null,
|
| 1412 |
+
"visibility": null,
|
| 1413 |
+
"width": null
|
| 1414 |
+
}
|
| 1415 |
+
},
|
| 1416 |
+
"5efdceae0bac4c978d3a7226247e237f": {
|
| 1417 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1418 |
+
"model_module_version": "1.5.0",
|
| 1419 |
+
"model_name": "HBoxModel",
|
| 1420 |
+
"state": {
|
| 1421 |
+
"_dom_classes": [],
|
| 1422 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1423 |
+
"_model_module_version": "1.5.0",
|
| 1424 |
+
"_model_name": "HBoxModel",
|
| 1425 |
+
"_view_count": null,
|
| 1426 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1427 |
+
"_view_module_version": "1.5.0",
|
| 1428 |
+
"_view_name": "HBoxView",
|
| 1429 |
+
"box_style": "",
|
| 1430 |
+
"children": [
|
| 1431 |
+
"IPY_MODEL_a39c5c623a3e42448e109fb9ec6bc263",
|
| 1432 |
+
"IPY_MODEL_a6ed2ddb1c6f4d1aa945c5a39372f781",
|
| 1433 |
+
"IPY_MODEL_8cf950b898e142c1af9b4db92019aa4d"
|
| 1434 |
+
],
|
| 1435 |
+
"layout": "IPY_MODEL_8ed7abd0602c43a1bfc0f96d7611d429"
|
| 1436 |
+
}
|
| 1437 |
+
},
|
| 1438 |
+
"5f5e7ff6e4c845b99602a4fa00ad550a": {
|
| 1439 |
+
"model_module": "@jupyter-widgets/base",
|
| 1440 |
+
"model_module_version": "1.2.0",
|
| 1441 |
+
"model_name": "LayoutModel",
|
| 1442 |
+
"state": {
|
| 1443 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1444 |
+
"_model_module_version": "1.2.0",
|
| 1445 |
+
"_model_name": "LayoutModel",
|
| 1446 |
+
"_view_count": null,
|
| 1447 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1448 |
+
"_view_module_version": "1.2.0",
|
| 1449 |
+
"_view_name": "LayoutView",
|
| 1450 |
+
"align_content": null,
|
| 1451 |
+
"align_items": null,
|
| 1452 |
+
"align_self": null,
|
| 1453 |
+
"border": null,
|
| 1454 |
+
"bottom": null,
|
| 1455 |
+
"display": null,
|
| 1456 |
+
"flex": null,
|
| 1457 |
+
"flex_flow": null,
|
| 1458 |
+
"grid_area": null,
|
| 1459 |
+
"grid_auto_columns": null,
|
| 1460 |
+
"grid_auto_flow": null,
|
| 1461 |
+
"grid_auto_rows": null,
|
| 1462 |
+
"grid_column": null,
|
| 1463 |
+
"grid_gap": null,
|
| 1464 |
+
"grid_row": null,
|
| 1465 |
+
"grid_template_areas": null,
|
| 1466 |
+
"grid_template_columns": null,
|
| 1467 |
+
"grid_template_rows": null,
|
| 1468 |
+
"height": null,
|
| 1469 |
+
"justify_content": null,
|
| 1470 |
+
"justify_items": null,
|
| 1471 |
+
"left": null,
|
| 1472 |
+
"margin": null,
|
| 1473 |
+
"max_height": null,
|
| 1474 |
+
"max_width": null,
|
| 1475 |
+
"min_height": null,
|
| 1476 |
+
"min_width": null,
|
| 1477 |
+
"object_fit": null,
|
| 1478 |
+
"object_position": null,
|
| 1479 |
+
"order": null,
|
| 1480 |
+
"overflow": null,
|
| 1481 |
+
"overflow_x": null,
|
| 1482 |
+
"overflow_y": null,
|
| 1483 |
+
"padding": null,
|
| 1484 |
+
"right": null,
|
| 1485 |
+
"top": null,
|
| 1486 |
+
"visibility": null,
|
| 1487 |
+
"width": null
|
| 1488 |
+
}
|
| 1489 |
+
},
|
| 1490 |
+
"64277772cc30408e8ea29f0e268c8880": {
|
| 1491 |
+
"model_module": "@jupyter-widgets/base",
|
| 1492 |
+
"model_module_version": "1.2.0",
|
| 1493 |
+
"model_name": "LayoutModel",
|
| 1494 |
+
"state": {
|
| 1495 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1496 |
+
"_model_module_version": "1.2.0",
|
| 1497 |
+
"_model_name": "LayoutModel",
|
| 1498 |
+
"_view_count": null,
|
| 1499 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1500 |
+
"_view_module_version": "1.2.0",
|
| 1501 |
+
"_view_name": "LayoutView",
|
| 1502 |
+
"align_content": null,
|
| 1503 |
+
"align_items": null,
|
| 1504 |
+
"align_self": null,
|
| 1505 |
+
"border": null,
|
| 1506 |
+
"bottom": null,
|
| 1507 |
+
"display": null,
|
| 1508 |
+
"flex": null,
|
| 1509 |
+
"flex_flow": null,
|
| 1510 |
+
"grid_area": null,
|
| 1511 |
+
"grid_auto_columns": null,
|
| 1512 |
+
"grid_auto_flow": null,
|
| 1513 |
+
"grid_auto_rows": null,
|
| 1514 |
+
"grid_column": null,
|
| 1515 |
+
"grid_gap": null,
|
| 1516 |
+
"grid_row": null,
|
| 1517 |
+
"grid_template_areas": null,
|
| 1518 |
+
"grid_template_columns": null,
|
| 1519 |
+
"grid_template_rows": null,
|
| 1520 |
+
"height": null,
|
| 1521 |
+
"justify_content": null,
|
| 1522 |
+
"justify_items": null,
|
| 1523 |
+
"left": null,
|
| 1524 |
+
"margin": null,
|
| 1525 |
+
"max_height": null,
|
| 1526 |
+
"max_width": null,
|
| 1527 |
+
"min_height": null,
|
| 1528 |
+
"min_width": null,
|
| 1529 |
+
"object_fit": null,
|
| 1530 |
+
"object_position": null,
|
| 1531 |
+
"order": null,
|
| 1532 |
+
"overflow": null,
|
| 1533 |
+
"overflow_x": null,
|
| 1534 |
+
"overflow_y": null,
|
| 1535 |
+
"padding": null,
|
| 1536 |
+
"right": null,
|
| 1537 |
+
"top": null,
|
| 1538 |
+
"visibility": null,
|
| 1539 |
+
"width": null
|
| 1540 |
+
}
|
| 1541 |
+
},
|
| 1542 |
+
"65ba2d78fde14bb2baf5ae1101d7e5ff": {
|
| 1543 |
+
"model_module": "@jupyter-widgets/base",
|
| 1544 |
+
"model_module_version": "1.2.0",
|
| 1545 |
+
"model_name": "LayoutModel",
|
| 1546 |
+
"state": {
|
| 1547 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1548 |
+
"_model_module_version": "1.2.0",
|
| 1549 |
+
"_model_name": "LayoutModel",
|
| 1550 |
+
"_view_count": null,
|
| 1551 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1552 |
+
"_view_module_version": "1.2.0",
|
| 1553 |
+
"_view_name": "LayoutView",
|
| 1554 |
+
"align_content": null,
|
| 1555 |
+
"align_items": null,
|
| 1556 |
+
"align_self": null,
|
| 1557 |
+
"border": null,
|
| 1558 |
+
"bottom": null,
|
| 1559 |
+
"display": null,
|
| 1560 |
+
"flex": null,
|
| 1561 |
+
"flex_flow": null,
|
| 1562 |
+
"grid_area": null,
|
| 1563 |
+
"grid_auto_columns": null,
|
| 1564 |
+
"grid_auto_flow": null,
|
| 1565 |
+
"grid_auto_rows": null,
|
| 1566 |
+
"grid_column": null,
|
| 1567 |
+
"grid_gap": null,
|
| 1568 |
+
"grid_row": null,
|
| 1569 |
+
"grid_template_areas": null,
|
| 1570 |
+
"grid_template_columns": null,
|
| 1571 |
+
"grid_template_rows": null,
|
| 1572 |
+
"height": null,
|
| 1573 |
+
"justify_content": null,
|
| 1574 |
+
"justify_items": null,
|
| 1575 |
+
"left": null,
|
| 1576 |
+
"margin": null,
|
| 1577 |
+
"max_height": null,
|
| 1578 |
+
"max_width": null,
|
| 1579 |
+
"min_height": null,
|
| 1580 |
+
"min_width": null,
|
| 1581 |
+
"object_fit": null,
|
| 1582 |
+
"object_position": null,
|
| 1583 |
+
"order": null,
|
| 1584 |
+
"overflow": null,
|
| 1585 |
+
"overflow_x": null,
|
| 1586 |
+
"overflow_y": null,
|
| 1587 |
+
"padding": null,
|
| 1588 |
+
"right": null,
|
| 1589 |
+
"top": null,
|
| 1590 |
+
"visibility": null,
|
| 1591 |
+
"width": null
|
| 1592 |
+
}
|
| 1593 |
+
},
|
| 1594 |
+
"7e21c6a9c7f44496b6f28513caefb631": {
|
| 1595 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1596 |
+
"model_module_version": "1.5.0",
|
| 1597 |
+
"model_name": "HBoxModel",
|
| 1598 |
+
"state": {
|
| 1599 |
+
"_dom_classes": [],
|
| 1600 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1601 |
+
"_model_module_version": "1.5.0",
|
| 1602 |
+
"_model_name": "HBoxModel",
|
| 1603 |
+
"_view_count": null,
|
| 1604 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1605 |
+
"_view_module_version": "1.5.0",
|
| 1606 |
+
"_view_name": "HBoxView",
|
| 1607 |
+
"box_style": "",
|
| 1608 |
+
"children": [
|
| 1609 |
+
"IPY_MODEL_439eba0eb4184c0ab83f65fc26bbe388",
|
| 1610 |
+
"IPY_MODEL_eee695744ec64aa7b71b9e85968c6f8f",
|
| 1611 |
+
"IPY_MODEL_c4ecdc9d982f49129368893c1c0aece9"
|
| 1612 |
+
],
|
| 1613 |
+
"layout": "IPY_MODEL_5f5e7ff6e4c845b99602a4fa00ad550a"
|
| 1614 |
+
}
|
| 1615 |
+
},
|
| 1616 |
+
"7e89bc79516f405e9684eacdce7b4551": {
|
| 1617 |
+
"model_module": "@jupyter-widgets/base",
|
| 1618 |
+
"model_module_version": "1.2.0",
|
| 1619 |
+
"model_name": "LayoutModel",
|
| 1620 |
+
"state": {
|
| 1621 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1622 |
+
"_model_module_version": "1.2.0",
|
| 1623 |
+
"_model_name": "LayoutModel",
|
| 1624 |
+
"_view_count": null,
|
| 1625 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1626 |
+
"_view_module_version": "1.2.0",
|
| 1627 |
+
"_view_name": "LayoutView",
|
| 1628 |
+
"align_content": null,
|
| 1629 |
+
"align_items": null,
|
| 1630 |
+
"align_self": null,
|
| 1631 |
+
"border": null,
|
| 1632 |
+
"bottom": null,
|
| 1633 |
+
"display": null,
|
| 1634 |
+
"flex": null,
|
| 1635 |
+
"flex_flow": null,
|
| 1636 |
+
"grid_area": null,
|
| 1637 |
+
"grid_auto_columns": null,
|
| 1638 |
+
"grid_auto_flow": null,
|
| 1639 |
+
"grid_auto_rows": null,
|
| 1640 |
+
"grid_column": null,
|
| 1641 |
+
"grid_gap": null,
|
| 1642 |
+
"grid_row": null,
|
| 1643 |
+
"grid_template_areas": null,
|
| 1644 |
+
"grid_template_columns": null,
|
| 1645 |
+
"grid_template_rows": null,
|
| 1646 |
+
"height": null,
|
| 1647 |
+
"justify_content": null,
|
| 1648 |
+
"justify_items": null,
|
| 1649 |
+
"left": null,
|
| 1650 |
+
"margin": null,
|
| 1651 |
+
"max_height": null,
|
| 1652 |
+
"max_width": null,
|
| 1653 |
+
"min_height": null,
|
| 1654 |
+
"min_width": null,
|
| 1655 |
+
"object_fit": null,
|
| 1656 |
+
"object_position": null,
|
| 1657 |
+
"order": null,
|
| 1658 |
+
"overflow": null,
|
| 1659 |
+
"overflow_x": null,
|
| 1660 |
+
"overflow_y": null,
|
| 1661 |
+
"padding": null,
|
| 1662 |
+
"right": null,
|
| 1663 |
+
"top": null,
|
| 1664 |
+
"visibility": null,
|
| 1665 |
+
"width": null
|
| 1666 |
+
}
|
| 1667 |
+
},
|
| 1668 |
+
"8226a55726c54abba3a48dbfa8e1b6f6": {
|
| 1669 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1670 |
+
"model_module_version": "1.5.0",
|
| 1671 |
+
"model_name": "DescriptionStyleModel",
|
| 1672 |
+
"state": {
|
| 1673 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1674 |
+
"_model_module_version": "1.5.0",
|
| 1675 |
+
"_model_name": "DescriptionStyleModel",
|
| 1676 |
+
"_view_count": null,
|
| 1677 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1678 |
+
"_view_module_version": "1.2.0",
|
| 1679 |
+
"_view_name": "StyleView",
|
| 1680 |
+
"description_width": ""
|
| 1681 |
+
}
|
| 1682 |
+
},
|
| 1683 |
+
"828b227361fe45cd83964149e7475503": {
|
| 1684 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1685 |
+
"model_module_version": "1.5.0",
|
| 1686 |
+
"model_name": "ProgressStyleModel",
|
| 1687 |
+
"state": {
|
| 1688 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1689 |
+
"_model_module_version": "1.5.0",
|
| 1690 |
+
"_model_name": "ProgressStyleModel",
|
| 1691 |
+
"_view_count": null,
|
| 1692 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1693 |
+
"_view_module_version": "1.2.0",
|
| 1694 |
+
"_view_name": "StyleView",
|
| 1695 |
+
"bar_color": null,
|
| 1696 |
+
"description_width": ""
|
| 1697 |
+
}
|
| 1698 |
+
},
|
| 1699 |
+
"86a3c1a4e9eb4989b23364f21e5df531": {
|
| 1700 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1701 |
+
"model_module_version": "1.5.0",
|
| 1702 |
+
"model_name": "HBoxModel",
|
| 1703 |
+
"state": {
|
| 1704 |
+
"_dom_classes": [],
|
| 1705 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1706 |
+
"_model_module_version": "1.5.0",
|
| 1707 |
+
"_model_name": "HBoxModel",
|
| 1708 |
+
"_view_count": null,
|
| 1709 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1710 |
+
"_view_module_version": "1.5.0",
|
| 1711 |
+
"_view_name": "HBoxView",
|
| 1712 |
+
"box_style": "",
|
| 1713 |
+
"children": [
|
| 1714 |
+
"IPY_MODEL_5ba39d9d997a45ca848e3e2ffd0e7307",
|
| 1715 |
+
"IPY_MODEL_4c22e1b396f342ffb90c1b50a0051862",
|
| 1716 |
+
"IPY_MODEL_370e5663868f411697bfb24f4e3efa09"
|
| 1717 |
+
],
|
| 1718 |
+
"layout": "IPY_MODEL_3a338ac4d2944030a07843d8ea24e9fd"
|
| 1719 |
+
}
|
| 1720 |
+
},
|
| 1721 |
+
"8cf950b898e142c1af9b4db92019aa4d": {
|
| 1722 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1723 |
+
"model_module_version": "1.5.0",
|
| 1724 |
+
"model_name": "HTMLModel",
|
| 1725 |
+
"state": {
|
| 1726 |
+
"_dom_classes": [],
|
| 1727 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1728 |
+
"_model_module_version": "1.5.0",
|
| 1729 |
+
"_model_name": "HTMLModel",
|
| 1730 |
+
"_view_count": null,
|
| 1731 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1732 |
+
"_view_module_version": "1.5.0",
|
| 1733 |
+
"_view_name": "HTMLView",
|
| 1734 |
+
"description": "",
|
| 1735 |
+
"description_tooltip": null,
|
| 1736 |
+
"layout": "IPY_MODEL_57e526d188b9414dabb3b1c895373864",
|
| 1737 |
+
"placeholder": "",
|
| 1738 |
+
"style": "IPY_MODEL_8226a55726c54abba3a48dbfa8e1b6f6",
|
| 1739 |
+
"value": " 286/286 [00:00<00:00, 25.2kB/s]"
|
| 1740 |
+
}
|
| 1741 |
+
},
|
| 1742 |
+
"8ed7abd0602c43a1bfc0f96d7611d429": {
|
| 1743 |
+
"model_module": "@jupyter-widgets/base",
|
| 1744 |
+
"model_module_version": "1.2.0",
|
| 1745 |
+
"model_name": "LayoutModel",
|
| 1746 |
+
"state": {
|
| 1747 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1748 |
+
"_model_module_version": "1.2.0",
|
| 1749 |
+
"_model_name": "LayoutModel",
|
| 1750 |
+
"_view_count": null,
|
| 1751 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1752 |
+
"_view_module_version": "1.2.0",
|
| 1753 |
+
"_view_name": "LayoutView",
|
| 1754 |
+
"align_content": null,
|
| 1755 |
+
"align_items": null,
|
| 1756 |
+
"align_self": null,
|
| 1757 |
+
"border": null,
|
| 1758 |
+
"bottom": null,
|
| 1759 |
+
"display": null,
|
| 1760 |
+
"flex": null,
|
| 1761 |
+
"flex_flow": null,
|
| 1762 |
+
"grid_area": null,
|
| 1763 |
+
"grid_auto_columns": null,
|
| 1764 |
+
"grid_auto_flow": null,
|
| 1765 |
+
"grid_auto_rows": null,
|
| 1766 |
+
"grid_column": null,
|
| 1767 |
+
"grid_gap": null,
|
| 1768 |
+
"grid_row": null,
|
| 1769 |
+
"grid_template_areas": null,
|
| 1770 |
+
"grid_template_columns": null,
|
| 1771 |
+
"grid_template_rows": null,
|
| 1772 |
+
"height": null,
|
| 1773 |
+
"justify_content": null,
|
| 1774 |
+
"justify_items": null,
|
| 1775 |
+
"left": null,
|
| 1776 |
+
"margin": null,
|
| 1777 |
+
"max_height": null,
|
| 1778 |
+
"max_width": null,
|
| 1779 |
+
"min_height": null,
|
| 1780 |
+
"min_width": null,
|
| 1781 |
+
"object_fit": null,
|
| 1782 |
+
"object_position": null,
|
| 1783 |
+
"order": null,
|
| 1784 |
+
"overflow": null,
|
| 1785 |
+
"overflow_x": null,
|
| 1786 |
+
"overflow_y": null,
|
| 1787 |
+
"padding": null,
|
| 1788 |
+
"right": null,
|
| 1789 |
+
"top": null,
|
| 1790 |
+
"visibility": null,
|
| 1791 |
+
"width": null
|
| 1792 |
+
}
|
| 1793 |
+
},
|
| 1794 |
+
"926149594f94457295c60b4fad9cbac7": {
|
| 1795 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1796 |
+
"model_module_version": "1.5.0",
|
| 1797 |
+
"model_name": "ProgressStyleModel",
|
| 1798 |
+
"state": {
|
| 1799 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1800 |
+
"_model_module_version": "1.5.0",
|
| 1801 |
+
"_model_name": "ProgressStyleModel",
|
| 1802 |
+
"_view_count": null,
|
| 1803 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1804 |
+
"_view_module_version": "1.2.0",
|
| 1805 |
+
"_view_name": "StyleView",
|
| 1806 |
+
"bar_color": null,
|
| 1807 |
+
"description_width": ""
|
| 1808 |
+
}
|
| 1809 |
+
},
|
| 1810 |
+
"a39c5c623a3e42448e109fb9ec6bc263": {
|
| 1811 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1812 |
+
"model_module_version": "1.5.0",
|
| 1813 |
+
"model_name": "HTMLModel",
|
| 1814 |
+
"state": {
|
| 1815 |
+
"_dom_classes": [],
|
| 1816 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1817 |
+
"_model_module_version": "1.5.0",
|
| 1818 |
+
"_model_name": "HTMLModel",
|
| 1819 |
+
"_view_count": null,
|
| 1820 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1821 |
+
"_view_module_version": "1.5.0",
|
| 1822 |
+
"_view_name": "HTMLView",
|
| 1823 |
+
"description": "",
|
| 1824 |
+
"description_tooltip": null,
|
| 1825 |
+
"layout": "IPY_MODEL_65ba2d78fde14bb2baf5ae1101d7e5ff",
|
| 1826 |
+
"placeholder": "",
|
| 1827 |
+
"style": "IPY_MODEL_4795a78a75dc439a8da7df58bf738940",
|
| 1828 |
+
"value": "config.json: 100%"
|
| 1829 |
+
}
|
| 1830 |
+
},
|
| 1831 |
+
"a5a9f8607fdd4f9cad7519eca573f3dc": {
|
| 1832 |
+
"model_module": "@jupyter-widgets/base",
|
| 1833 |
+
"model_module_version": "1.2.0",
|
| 1834 |
+
"model_name": "LayoutModel",
|
| 1835 |
+
"state": {
|
| 1836 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1837 |
+
"_model_module_version": "1.2.0",
|
| 1838 |
+
"_model_name": "LayoutModel",
|
| 1839 |
+
"_view_count": null,
|
| 1840 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1841 |
+
"_view_module_version": "1.2.0",
|
| 1842 |
+
"_view_name": "LayoutView",
|
| 1843 |
+
"align_content": null,
|
| 1844 |
+
"align_items": null,
|
| 1845 |
+
"align_self": null,
|
| 1846 |
+
"border": null,
|
| 1847 |
+
"bottom": null,
|
| 1848 |
+
"display": null,
|
| 1849 |
+
"flex": null,
|
| 1850 |
+
"flex_flow": null,
|
| 1851 |
+
"grid_area": null,
|
| 1852 |
+
"grid_auto_columns": null,
|
| 1853 |
+
"grid_auto_flow": null,
|
| 1854 |
+
"grid_auto_rows": null,
|
| 1855 |
+
"grid_column": null,
|
| 1856 |
+
"grid_gap": null,
|
| 1857 |
+
"grid_row": null,
|
| 1858 |
+
"grid_template_areas": null,
|
| 1859 |
+
"grid_template_columns": null,
|
| 1860 |
+
"grid_template_rows": null,
|
| 1861 |
+
"height": null,
|
| 1862 |
+
"justify_content": null,
|
| 1863 |
+
"justify_items": null,
|
| 1864 |
+
"left": null,
|
| 1865 |
+
"margin": null,
|
| 1866 |
+
"max_height": null,
|
| 1867 |
+
"max_width": null,
|
| 1868 |
+
"min_height": null,
|
| 1869 |
+
"min_width": null,
|
| 1870 |
+
"object_fit": null,
|
| 1871 |
+
"object_position": null,
|
| 1872 |
+
"order": null,
|
| 1873 |
+
"overflow": null,
|
| 1874 |
+
"overflow_x": null,
|
| 1875 |
+
"overflow_y": null,
|
| 1876 |
+
"padding": null,
|
| 1877 |
+
"right": null,
|
| 1878 |
+
"top": null,
|
| 1879 |
+
"visibility": null,
|
| 1880 |
+
"width": "20px"
|
| 1881 |
+
}
|
| 1882 |
+
},
|
| 1883 |
+
"a6ed2ddb1c6f4d1aa945c5a39372f781": {
|
| 1884 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1885 |
+
"model_module_version": "1.5.0",
|
| 1886 |
+
"model_name": "FloatProgressModel",
|
| 1887 |
+
"state": {
|
| 1888 |
+
"_dom_classes": [],
|
| 1889 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1890 |
+
"_model_module_version": "1.5.0",
|
| 1891 |
+
"_model_name": "FloatProgressModel",
|
| 1892 |
+
"_view_count": null,
|
| 1893 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1894 |
+
"_view_module_version": "1.5.0",
|
| 1895 |
+
"_view_name": "ProgressView",
|
| 1896 |
+
"bar_style": "success",
|
| 1897 |
+
"description": "",
|
| 1898 |
+
"description_tooltip": null,
|
| 1899 |
+
"layout": "IPY_MODEL_4545ff199b874d3680a83918513e1d4b",
|
| 1900 |
+
"max": 286,
|
| 1901 |
+
"min": 0,
|
| 1902 |
+
"orientation": "horizontal",
|
| 1903 |
+
"style": "IPY_MODEL_cad8fd90586443778568a1babb8c40e6",
|
| 1904 |
+
"value": 286
|
| 1905 |
+
}
|
| 1906 |
+
},
|
| 1907 |
+
"ab61b90c1a5b4a2b9bb5c9d5a215bb3f": {
|
| 1908 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1909 |
+
"model_module_version": "1.5.0",
|
| 1910 |
+
"model_name": "HTMLModel",
|
| 1911 |
+
"state": {
|
| 1912 |
+
"_dom_classes": [],
|
| 1913 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1914 |
+
"_model_module_version": "1.5.0",
|
| 1915 |
+
"_model_name": "HTMLModel",
|
| 1916 |
+
"_view_count": null,
|
| 1917 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1918 |
+
"_view_module_version": "1.5.0",
|
| 1919 |
+
"_view_name": "HTMLView",
|
| 1920 |
+
"description": "",
|
| 1921 |
+
"description_tooltip": null,
|
| 1922 |
+
"layout": "IPY_MODEL_bf8eb066cdaf4ac096dc14392d085daf",
|
| 1923 |
+
"placeholder": "",
|
| 1924 |
+
"style": "IPY_MODEL_4e32e76c44fb449c8cb767abeb17868a",
|
| 1925 |
+
"value": "pytorch_model.bin: 100%"
|
| 1926 |
+
}
|
| 1927 |
+
},
|
| 1928 |
+
"b2bf751bb96746e4a828241f70e52050": {
|
| 1929 |
+
"model_module": "@jupyter-widgets/base",
|
| 1930 |
+
"model_module_version": "1.2.0",
|
| 1931 |
+
"model_name": "LayoutModel",
|
| 1932 |
+
"state": {
|
| 1933 |
+
"_model_module": "@jupyter-widgets/base",
|
| 1934 |
+
"_model_module_version": "1.2.0",
|
| 1935 |
+
"_model_name": "LayoutModel",
|
| 1936 |
+
"_view_count": null,
|
| 1937 |
+
"_view_module": "@jupyter-widgets/base",
|
| 1938 |
+
"_view_module_version": "1.2.0",
|
| 1939 |
+
"_view_name": "LayoutView",
|
| 1940 |
+
"align_content": null,
|
| 1941 |
+
"align_items": null,
|
| 1942 |
+
"align_self": null,
|
| 1943 |
+
"border": null,
|
| 1944 |
+
"bottom": null,
|
| 1945 |
+
"display": null,
|
| 1946 |
+
"flex": null,
|
| 1947 |
+
"flex_flow": null,
|
| 1948 |
+
"grid_area": null,
|
| 1949 |
+
"grid_auto_columns": null,
|
| 1950 |
+
"grid_auto_flow": null,
|
| 1951 |
+
"grid_auto_rows": null,
|
| 1952 |
+
"grid_column": null,
|
| 1953 |
+
"grid_gap": null,
|
| 1954 |
+
"grid_row": null,
|
| 1955 |
+
"grid_template_areas": null,
|
| 1956 |
+
"grid_template_columns": null,
|
| 1957 |
+
"grid_template_rows": null,
|
| 1958 |
+
"height": null,
|
| 1959 |
+
"justify_content": null,
|
| 1960 |
+
"justify_items": null,
|
| 1961 |
+
"left": null,
|
| 1962 |
+
"margin": null,
|
| 1963 |
+
"max_height": null,
|
| 1964 |
+
"max_width": null,
|
| 1965 |
+
"min_height": null,
|
| 1966 |
+
"min_width": null,
|
| 1967 |
+
"object_fit": null,
|
| 1968 |
+
"object_position": null,
|
| 1969 |
+
"order": null,
|
| 1970 |
+
"overflow": null,
|
| 1971 |
+
"overflow_x": null,
|
| 1972 |
+
"overflow_y": null,
|
| 1973 |
+
"padding": null,
|
| 1974 |
+
"right": null,
|
| 1975 |
+
"top": null,
|
| 1976 |
+
"visibility": null,
|
| 1977 |
+
"width": null
|
| 1978 |
+
}
|
| 1979 |
+
},
|
| 1980 |
+
"bdf500351aea42698c6d6dd5a99021f3": {
|
| 1981 |
+
"model_module": "@jupyter-widgets/controls",
|
| 1982 |
+
"model_module_version": "1.5.0",
|
| 1983 |
+
"model_name": "HBoxModel",
|
| 1984 |
+
"state": {
|
| 1985 |
+
"_dom_classes": [],
|
| 1986 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 1987 |
+
"_model_module_version": "1.5.0",
|
| 1988 |
+
"_model_name": "HBoxModel",
|
| 1989 |
+
"_view_count": null,
|
| 1990 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 1991 |
+
"_view_module_version": "1.5.0",
|
| 1992 |
+
"_view_name": "HBoxView",
|
| 1993 |
+
"box_style": "",
|
| 1994 |
+
"children": [
|
| 1995 |
+
"IPY_MODEL_ab61b90c1a5b4a2b9bb5c9d5a215bb3f",
|
| 1996 |
+
"IPY_MODEL_dc03fed540b74f3aa4a1b17ebf2c81d3",
|
| 1997 |
+
"IPY_MODEL_5837f2c4668646c0a6db2407aebb46e3"
|
| 1998 |
+
],
|
| 1999 |
+
"layout": "IPY_MODEL_edeb423e9ff84e5c8a0d790368d68bba"
|
| 2000 |
+
}
|
| 2001 |
+
},
|
| 2002 |
+
"bf8eb066cdaf4ac096dc14392d085daf": {
|
| 2003 |
+
"model_module": "@jupyter-widgets/base",
|
| 2004 |
+
"model_module_version": "1.2.0",
|
| 2005 |
+
"model_name": "LayoutModel",
|
| 2006 |
+
"state": {
|
| 2007 |
+
"_model_module": "@jupyter-widgets/base",
|
| 2008 |
+
"_model_module_version": "1.2.0",
|
| 2009 |
+
"_model_name": "LayoutModel",
|
| 2010 |
+
"_view_count": null,
|
| 2011 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2012 |
+
"_view_module_version": "1.2.0",
|
| 2013 |
+
"_view_name": "LayoutView",
|
| 2014 |
+
"align_content": null,
|
| 2015 |
+
"align_items": null,
|
| 2016 |
+
"align_self": null,
|
| 2017 |
+
"border": null,
|
| 2018 |
+
"bottom": null,
|
| 2019 |
+
"display": null,
|
| 2020 |
+
"flex": null,
|
| 2021 |
+
"flex_flow": null,
|
| 2022 |
+
"grid_area": null,
|
| 2023 |
+
"grid_auto_columns": null,
|
| 2024 |
+
"grid_auto_flow": null,
|
| 2025 |
+
"grid_auto_rows": null,
|
| 2026 |
+
"grid_column": null,
|
| 2027 |
+
"grid_gap": null,
|
| 2028 |
+
"grid_row": null,
|
| 2029 |
+
"grid_template_areas": null,
|
| 2030 |
+
"grid_template_columns": null,
|
| 2031 |
+
"grid_template_rows": null,
|
| 2032 |
+
"height": null,
|
| 2033 |
+
"justify_content": null,
|
| 2034 |
+
"justify_items": null,
|
| 2035 |
+
"left": null,
|
| 2036 |
+
"margin": null,
|
| 2037 |
+
"max_height": null,
|
| 2038 |
+
"max_width": null,
|
| 2039 |
+
"min_height": null,
|
| 2040 |
+
"min_width": null,
|
| 2041 |
+
"object_fit": null,
|
| 2042 |
+
"object_position": null,
|
| 2043 |
+
"order": null,
|
| 2044 |
+
"overflow": null,
|
| 2045 |
+
"overflow_x": null,
|
| 2046 |
+
"overflow_y": null,
|
| 2047 |
+
"padding": null,
|
| 2048 |
+
"right": null,
|
| 2049 |
+
"top": null,
|
| 2050 |
+
"visibility": null,
|
| 2051 |
+
"width": null
|
| 2052 |
+
}
|
| 2053 |
+
},
|
| 2054 |
+
"bfcc6d01c9ff4db698afa4318e7c91ac": {
|
| 2055 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2056 |
+
"model_module_version": "1.5.0",
|
| 2057 |
+
"model_name": "DescriptionStyleModel",
|
| 2058 |
+
"state": {
|
| 2059 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2060 |
+
"_model_module_version": "1.5.0",
|
| 2061 |
+
"_model_name": "DescriptionStyleModel",
|
| 2062 |
+
"_view_count": null,
|
| 2063 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2064 |
+
"_view_module_version": "1.2.0",
|
| 2065 |
+
"_view_name": "StyleView",
|
| 2066 |
+
"description_width": ""
|
| 2067 |
+
}
|
| 2068 |
+
},
|
| 2069 |
+
"c4ecdc9d982f49129368893c1c0aece9": {
|
| 2070 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2071 |
+
"model_module_version": "1.5.0",
|
| 2072 |
+
"model_name": "HTMLModel",
|
| 2073 |
+
"state": {
|
| 2074 |
+
"_dom_classes": [],
|
| 2075 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2076 |
+
"_model_module_version": "1.5.0",
|
| 2077 |
+
"_model_name": "HTMLModel",
|
| 2078 |
+
"_view_count": null,
|
| 2079 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 2080 |
+
"_view_module_version": "1.5.0",
|
| 2081 |
+
"_view_name": "HTMLView",
|
| 2082 |
+
"description": "",
|
| 2083 |
+
"description_tooltip": null,
|
| 2084 |
+
"layout": "IPY_MODEL_58ab975eaba2485cb0945482c26ecf3d",
|
| 2085 |
+
"placeholder": "",
|
| 2086 |
+
"style": "IPY_MODEL_d0b4e43ab5cd4edda6cc061b36bf10a3",
|
| 2087 |
+
"value": " 45.1M/45.1M [00:00<00:00, 89.1MB/s]"
|
| 2088 |
+
}
|
| 2089 |
+
},
|
| 2090 |
+
"c917f3a000fb44338e4afbeabeaab55f": {
|
| 2091 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2092 |
+
"model_module_version": "1.5.0",
|
| 2093 |
+
"model_name": "DescriptionStyleModel",
|
| 2094 |
+
"state": {
|
| 2095 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2096 |
+
"_model_module_version": "1.5.0",
|
| 2097 |
+
"_model_name": "DescriptionStyleModel",
|
| 2098 |
+
"_view_count": null,
|
| 2099 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2100 |
+
"_view_module_version": "1.2.0",
|
| 2101 |
+
"_view_name": "StyleView",
|
| 2102 |
+
"description_width": ""
|
| 2103 |
+
}
|
| 2104 |
+
},
|
| 2105 |
+
"cad8fd90586443778568a1babb8c40e6": {
|
| 2106 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2107 |
+
"model_module_version": "1.5.0",
|
| 2108 |
+
"model_name": "ProgressStyleModel",
|
| 2109 |
+
"state": {
|
| 2110 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2111 |
+
"_model_module_version": "1.5.0",
|
| 2112 |
+
"_model_name": "ProgressStyleModel",
|
| 2113 |
+
"_view_count": null,
|
| 2114 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2115 |
+
"_view_module_version": "1.2.0",
|
| 2116 |
+
"_view_name": "StyleView",
|
| 2117 |
+
"bar_color": null,
|
| 2118 |
+
"description_width": ""
|
| 2119 |
+
}
|
| 2120 |
+
},
|
| 2121 |
+
"d0b4e43ab5cd4edda6cc061b36bf10a3": {
|
| 2122 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2123 |
+
"model_module_version": "1.5.0",
|
| 2124 |
+
"model_name": "DescriptionStyleModel",
|
| 2125 |
+
"state": {
|
| 2126 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2127 |
+
"_model_module_version": "1.5.0",
|
| 2128 |
+
"_model_name": "DescriptionStyleModel",
|
| 2129 |
+
"_view_count": null,
|
| 2130 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2131 |
+
"_view_module_version": "1.2.0",
|
| 2132 |
+
"_view_name": "StyleView",
|
| 2133 |
+
"description_width": ""
|
| 2134 |
+
}
|
| 2135 |
+
},
|
| 2136 |
+
"dc03fed540b74f3aa4a1b17ebf2c81d3": {
|
| 2137 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2138 |
+
"model_module_version": "1.5.0",
|
| 2139 |
+
"model_name": "FloatProgressModel",
|
| 2140 |
+
"state": {
|
| 2141 |
+
"_dom_classes": [],
|
| 2142 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2143 |
+
"_model_module_version": "1.5.0",
|
| 2144 |
+
"_model_name": "FloatProgressModel",
|
| 2145 |
+
"_view_count": null,
|
| 2146 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 2147 |
+
"_view_module_version": "1.5.0",
|
| 2148 |
+
"_view_name": "ProgressView",
|
| 2149 |
+
"bar_style": "success",
|
| 2150 |
+
"description": "",
|
| 2151 |
+
"description_tooltip": null,
|
| 2152 |
+
"layout": "IPY_MODEL_5c3cb981f324446eae642f7c23a539f0",
|
| 2153 |
+
"max": 45106985,
|
| 2154 |
+
"min": 0,
|
| 2155 |
+
"orientation": "horizontal",
|
| 2156 |
+
"style": "IPY_MODEL_2fe9614fe5984fa6b887d1e1b3e18b04",
|
| 2157 |
+
"value": 45106985
|
| 2158 |
+
}
|
| 2159 |
+
},
|
| 2160 |
+
"edeb423e9ff84e5c8a0d790368d68bba": {
|
| 2161 |
+
"model_module": "@jupyter-widgets/base",
|
| 2162 |
+
"model_module_version": "1.2.0",
|
| 2163 |
+
"model_name": "LayoutModel",
|
| 2164 |
+
"state": {
|
| 2165 |
+
"_model_module": "@jupyter-widgets/base",
|
| 2166 |
+
"_model_module_version": "1.2.0",
|
| 2167 |
+
"_model_name": "LayoutModel",
|
| 2168 |
+
"_view_count": null,
|
| 2169 |
+
"_view_module": "@jupyter-widgets/base",
|
| 2170 |
+
"_view_module_version": "1.2.0",
|
| 2171 |
+
"_view_name": "LayoutView",
|
| 2172 |
+
"align_content": null,
|
| 2173 |
+
"align_items": null,
|
| 2174 |
+
"align_self": null,
|
| 2175 |
+
"border": null,
|
| 2176 |
+
"bottom": null,
|
| 2177 |
+
"display": null,
|
| 2178 |
+
"flex": null,
|
| 2179 |
+
"flex_flow": null,
|
| 2180 |
+
"grid_area": null,
|
| 2181 |
+
"grid_auto_columns": null,
|
| 2182 |
+
"grid_auto_flow": null,
|
| 2183 |
+
"grid_auto_rows": null,
|
| 2184 |
+
"grid_column": null,
|
| 2185 |
+
"grid_gap": null,
|
| 2186 |
+
"grid_row": null,
|
| 2187 |
+
"grid_template_areas": null,
|
| 2188 |
+
"grid_template_columns": null,
|
| 2189 |
+
"grid_template_rows": null,
|
| 2190 |
+
"height": null,
|
| 2191 |
+
"justify_content": null,
|
| 2192 |
+
"justify_items": null,
|
| 2193 |
+
"left": null,
|
| 2194 |
+
"margin": null,
|
| 2195 |
+
"max_height": null,
|
| 2196 |
+
"max_width": null,
|
| 2197 |
+
"min_height": null,
|
| 2198 |
+
"min_width": null,
|
| 2199 |
+
"object_fit": null,
|
| 2200 |
+
"object_position": null,
|
| 2201 |
+
"order": null,
|
| 2202 |
+
"overflow": null,
|
| 2203 |
+
"overflow_x": null,
|
| 2204 |
+
"overflow_y": null,
|
| 2205 |
+
"padding": null,
|
| 2206 |
+
"right": null,
|
| 2207 |
+
"top": null,
|
| 2208 |
+
"visibility": null,
|
| 2209 |
+
"width": null
|
| 2210 |
+
}
|
| 2211 |
+
},
|
| 2212 |
+
"eee695744ec64aa7b71b9e85968c6f8f": {
|
| 2213 |
+
"model_module": "@jupyter-widgets/controls",
|
| 2214 |
+
"model_module_version": "1.5.0",
|
| 2215 |
+
"model_name": "FloatProgressModel",
|
| 2216 |
+
"state": {
|
| 2217 |
+
"_dom_classes": [],
|
| 2218 |
+
"_model_module": "@jupyter-widgets/controls",
|
| 2219 |
+
"_model_module_version": "1.5.0",
|
| 2220 |
+
"_model_name": "FloatProgressModel",
|
| 2221 |
+
"_view_count": null,
|
| 2222 |
+
"_view_module": "@jupyter-widgets/controls",
|
| 2223 |
+
"_view_module_version": "1.5.0",
|
| 2224 |
+
"_view_name": "ProgressView",
|
| 2225 |
+
"bar_style": "success",
|
| 2226 |
+
"description": "",
|
| 2227 |
+
"description_tooltip": null,
|
| 2228 |
+
"layout": "IPY_MODEL_b2bf751bb96746e4a828241f70e52050",
|
| 2229 |
+
"max": 45084768,
|
| 2230 |
+
"min": 0,
|
| 2231 |
+
"orientation": "horizontal",
|
| 2232 |
+
"style": "IPY_MODEL_828b227361fe45cd83964149e7475503",
|
| 2233 |
+
"value": 45084768
|
| 2234 |
+
}
|
| 2235 |
+
}
|
| 2236 |
+
}
|
| 2237 |
+
}
|
| 2238 |
+
},
|
| 2239 |
+
"nbformat": 4,
|
| 2240 |
+
"nbformat_minor": 4
|
| 2241 |
+
}
|
pikapikagen/README.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: pikapikagen
|
| 3 |
+
app_file: gradio_demo.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 5.35.0
|
| 6 |
+
---
|
pikapikagen/__init__.py
ADDED
|
File without changes
|
pikapikagen/data_loader.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader, Subset
|
| 2 |
+
import torch
|
| 3 |
+
from dataset import PokemonDataset
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
def create_training_setup(
    tokenizer,
    test_set_size,
    val_set_size,
    batch_size,
    num_workers=0,
    num_viz_samples=4,
    random_seed=42,
    train_augmentation_pipeline=None,
):
    """
    Create a complete setup for training with dataset, dataloaders and fixed batches for visualization.

    Args:
        tokenizer: tokenizer forwarded to PokemonDataset for text encoding.
        test_set_size: fraction of the dataset reserved for testing (0 <= x < 1).
        val_set_size: fraction of the dataset reserved for validation (0 <= x < 1).
        batch_size: batch size shared by the train/val/test loaders.
        num_workers: worker processes per DataLoader.
        num_viz_samples: size of the fixed visualization batches.
        random_seed: seed controlling both the split permutation and the
            visualization-batch sampling.
        train_augmentation_pipeline: optional transform applied only to the
            training images.

    Returns:
        dict with the three DataLoaders, the three Subset datasets, and fixed
        batches (size num_viz_samples, plus size-1 batches for attention maps).
    """
    assert 0 <= test_set_size < 1.0, "test_set_size must be a float between 0 and 1"
    assert 0 <= val_set_size < 1.0, "val_set_size must be a float between 0 and 1"
    assert (test_set_size + val_set_size) < 1.0, "The sum of test and validation sizes must be less than 1"

    train_full_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_transforms=train_augmentation_pipeline)
    # Don't use augmentation for test and validation: a second dataset instance
    # over the same underlying data, indexed identically to the augmented one.
    test_val_full_dataset = PokemonDataset(tokenizer=tokenizer)

    dataset_size = len(train_full_dataset)

    # Create a random reproducible permutation
    generator = torch.Generator().manual_seed(random_seed)
    shuffled_indices = torch.randperm(dataset_size, generator=generator)

    # ceil keeps val/test at least as large as requested; train gets the rest.
    val_count = math.ceil(val_set_size * dataset_size)
    test_count = math.ceil(test_set_size * dataset_size)
    train_count = dataset_size - val_count - test_count

    # Partition based on the computed splits
    train_indices = shuffled_indices[:train_count].tolist()
    test_indices = shuffled_indices[train_count : train_count + test_count].tolist()
    val_indices = shuffled_indices[train_count + test_count :].tolist()

    # Create the subsets based on the indices
    train_dataset = Subset(train_full_dataset, train_indices)
    test_dataset = Subset(test_val_full_dataset, test_indices)
    val_dataset = Subset(test_val_full_dataset, val_indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Batch for visualization
    vis_generator = torch.Generator().manual_seed(random_seed)

    fixed_train_batch = next(
        iter(DataLoader(train_dataset, batch_size=num_viz_samples, shuffle=True, generator=vis_generator))
    )
    # Since no shuffle, a generator is not needed
    fixed_test_batch = next(iter(DataLoader(test_dataset, batch_size=num_viz_samples, shuffle=False)))
    fixed_val_batch = next(iter(DataLoader(val_dataset, batch_size=num_viz_samples, shuffle=False)))

    # Size-1 batch for attention map visualization. Reseeding vis_generator here
    # restores the RNG state consumed by the loader above, so the shuffled
    # DataLoader draws the same leading sample as the visualization batch.
    vis_generator.manual_seed(random_seed)
    fixed_train_attention_batch = next(
        iter(DataLoader(train_dataset, batch_size=1, shuffle=True, generator=vis_generator))
    )
    fixed_test_attention_batch = next(iter(DataLoader(test_dataset, batch_size=1, shuffle=False)))
    fixed_val_attention_batch = next(iter(DataLoader(val_dataset, batch_size=1, shuffle=False)))

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'fixed_train_batch': fixed_train_batch,
        'fixed_val_batch': fixed_val_batch,
        'fixed_test_batch': fixed_test_batch,
        'fixed_train_attention_batch': fixed_train_attention_batch,
        'fixed_val_attention_batch': fixed_val_attention_batch,
        'fixed_test_attention_batch': fixed_test_attention_batch,
    }
|
pikapikagen/dataset.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import urllib.request
|
| 3 |
+
import zipfile
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from typing import TypedDict
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PokemonSample(TypedDict):
    """One dataset item as returned by PokemonDataset.__getitem__."""

    text: torch.Tensor  # Text already tokenized
    image: torch.Tensor  # Image tensor (normalized per the dataset's transform)
    description: str  # Text before tokenization
    pokemon_name: str  # English name taken from the CSV row
    idx: int  # Dataset index of this sample
    attention_mask: torch.Tensor  # Tokenizer attention mask matching `text`
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def reporthook(block_num, block_size, total_size):
    """Progress hook for urllib.request.urlretrieve: print download progress.

    Args:
        block_num: number of blocks transferred so far.
        block_size: size of each transfer block in bytes.
        total_size: total file size in bytes, or a value <= 0 if unknown.
    """
    # The original threshold of 16384 blocks (~128 MB at urlretrieve's default
    # 8 KiB block size) meant no progress was ever shown for this ~45 MB
    # archive; report every 1024 blocks (~8 MB) instead.
    if block_num % 1024 != 0:
        return
    downloaded = block_num * block_size
    if total_size > 0:
        # The final block can overshoot the real file size; clamp for display.
        downloaded = min(downloaded, total_size)
        print(
            f"Downloading... {downloaded / (1024 * 1024):.2f} MB"
            f" of {total_size / (1024 * 1024):.2f} MB"
        )
    else:
        print(f"Downloading... {downloaded / (1024 * 1024):.2f} MB")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def download_dataset_if_not_exists():
    """Fetch and unpack the Pokedex dataset archive, skipping all work when a
    previous run has already extracted it."""
    dataset_dir = "dataset"
    pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")

    # Early exit: the extracted tree is the marker that the download happened.
    if os.path.exists(pokedex_main_dir):
        print(f"{pokedex_main_dir} already exists. Skipping download.")
        return

    archive_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
    archive_path = "pokedex_main.zip"

    os.makedirs(dataset_dir, exist_ok=True)

    print("Downloading dataset...")
    urllib.request.urlretrieve(archive_url, archive_path, reporthook)
    print("Download complete.")

    print("Extracting dataset...")
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(dataset_dir)
    print("Extraction complete.")

    # The archive is no longer needed once its contents are on disk.
    os.remove(archive_path)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PokemonDataset(Dataset):
    """Text-to-image Pokemon dataset: pairs each Pokedex description with its sprite.

    Each item bundles the tokenized description, the (optionally augmented)
    image tensor normalized to [-1, 1], and bookkeeping metadata
    (see PokemonSample).
    """

    def __init__(
        self,
        tokenizer,
        csv_path="dataset/pokedex-main/data/pokemon.csv",
        image_dir="dataset/pokedex-main/images/small_images",
        max_length=128,
        augmentation_transforms=None,
    ):
        """
        Args:
            tokenizer: HuggingFace-style tokenizer applied to the descriptions.
            csv_path: tab-separated metadata file (UTF-16 LE encoded upstream).
            image_dir: directory containing "<national_number:03d>.png" sprites.
            max_length: fixed token length (padding/truncation target).
            augmentation_transforms: optional transform inserted after resizing
                and before normalization (training-time augmentation).
        """
        # The upstream CSV is UTF-16 LE with tab delimiters.
        self.df = pd.read_csv(csv_path, encoding="utf-16 LE", delimiter="\t")
        self.image_dir = Path(image_dir)
        print(f"Dataset caricato: {len(self.df)} Pokemon con descrizioni e immagini")

        self.tokenizer = tokenizer
        self.max_length = max_length

        # Build a single pipeline and insert the augmentation stage only when
        # provided, instead of duplicating the whole Compose in two branches.
        steps = [
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
        ]
        if augmentation_transforms is not None:
            steps.append(augmentation_transforms)
        # Normalize to [-1, 1]
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        self.final_transform = transforms.Compose(steps)

    def __len__(self):
        """Number of Pokemon entries in the CSV."""
        return len(self.df)

    def __getitem__(self, idx: int) -> PokemonSample:
        # Row for the requested Pokemon.
        row = self.df.iloc[idx]

        # === TEXT PREPROCESSING ===
        description = str(row["description"])

        # Tokenize to a fixed length so samples can be stacked into batches.
        encoded = self.tokenizer(
            description,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Drop the singleton batch dimension added by return_tensors="pt".
        text_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        # === IMAGE LOADING AND PREPROCESSING ===
        # Sprites are named with the zero-padded national number, e.g. 001.png.
        image_filename = f"{row['national_number']:03d}.png"
        image_path = self.image_dir / image_filename

        image_rgba = Image.open(image_path).convert("RGBA")

        # Handle transparency: composite onto a white background, using the
        # alpha channel as the paste mask.
        background = Image.new("RGB", image_rgba.size, (255, 255, 255))
        background.paste(image_rgba, mask=image_rgba.split()[-1])

        # Apply the final transforms (ToTensor, Resize, optional augmentation, Normalize).
        image_tensor = self.final_transform(background)

        # Assemble the result (matches pokemon_dataset.py structure)
        sample = {
            "text": text_ids,
            "image": image_tensor,
            "description": description,
            "pokemon_name": row["english_name"],
            "idx": idx,
            "attention_mask": attention_mask,
        }

        return sample
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# NOTE(review): these run at import time, so merely importing this module can
# trigger a network download; consider gating behind an explicit call if that
# side effect is undesired.
download_dataset_if_not_exists()
print("Dataset ready!")
|
pikapikagen/discriminators.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model_blocks.text_encoder import TextEncoder
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Discriminator256(nn.Module):
    """Dual-head discriminator for 256x256 images.

    Produces an unconditional real/fake logit from the image alone and,
    optionally, a conditional logit that also sees an encoding of the text,
    so the GAN can penalize both unrealistic and text-mismatched images.
    """

    def __init__(self, text_dim=256, img_channels=3):
        """
        Args:
            text_dim: feature dimension of the encoded text (per token).
            img_channels: number of input image channels.
        """
        # Modern zero-argument super() (was super(Discriminator256, self)).
        super().__init__()

        self.text_encoder = TextEncoder()  # Separate text encoder for discriminators

        # Downsampling CNN: each Conv2d(kernel=4, stride=2, pad=1) halves the
        # spatial size, 256 -> 4 over six stages.
        self.img_path = nn.Sequential(
            # 256x256 -> 128x128
            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 128x128 -> 64x64
            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            # 64x64 -> 32x32
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Projects the pooled text embedding into the space that is concatenated
        # with the flattened image features.
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512)
        )

        # Unconditional classifier (real/fake without text conditioning)
        self.unconditional_classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

        # Conditional classifier (text-conditioned real/fake)
        self.conditional_classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4 + 512, 1024),  # size: sum of flattened image and text embedding
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        """Score a batch of images.

        Args:
            images: (B, img_channels, 256, 256) batch.
            text_features: token ids / text input for the discriminator's own
                text encoder; required when return_both is True.
            text_mask: attention mask matching text_features.
            return_both: when False, only the unconditional logit is returned.

        Returns:
            (unconditional_logits, conditional_logits) or just the
            unconditional logits when return_both is False.

        Raises:
            AttributeError: if return_both is True but text inputs are missing.
        """
        # Encode image
        img_features = self.img_path(images)
        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten to (B, 512*4*4)

        unconditional_output = self.unconditional_classifier(img_features_flat)

        if not return_both:
            return unconditional_output

        if text_features is None or text_mask is None:
            raise AttributeError("text_features and text_mask necessary for text conditioning")

        # Encode text and mean-pool over dim 1 (token dimension — assumes the
        # encoder returns (B, seq_len, text_dim); confirm in TextEncoder).
        global_full_text = self.text_encoder(text_features, text_mask)
        global_text = global_full_text.mean(dim=1)
        text_features_encoded = self.text_path(global_text)

        # Combine features
        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
        conditional_output = self.conditional_classifier(combined)

        return unconditional_output, conditional_output
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Discriminator64(nn.Module):
    """Dual-head discriminator for 64x64 images.

    Same structure as Discriminator256 but with four downsampling stages
    (64 -> 4) and correspondingly smaller classifier inputs.
    """

    def __init__(self, text_dim=256, img_channels=3):
        """
        Args:
            text_dim: feature dimension of the encoded text (per token).
            img_channels: number of input image channels.
        """
        # Modern zero-argument super() (was super(Discriminator64, self)).
        super().__init__()

        self.text_encoder = TextEncoder()

        # Downsampling CNN: each Conv2d(kernel=4, stride=2, pad=1) halves the
        # spatial size, 64 -> 4 over four stages.
        self.img_path = nn.Sequential(
            # 64x64 -> 32x32
            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Text encoder projection for the discriminator
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512)
        )

        # Unconditional classifier (real/fake without text conditioning)
        self.unconditional_classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

        # Conditional classifier (text-conditioned real/fake)
        self.conditional_classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4 + 512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        """Score a batch of 64x64 images; see Discriminator256.forward for the
        argument/return contract (identical here).

        Raises:
            AttributeError: if return_both is True but text inputs are missing.
        """
        img_features = self.img_path(images)
        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten to (B, 128*4*4)

        unconditional_output = self.unconditional_classifier(img_features_flat)

        if not return_both:
            return unconditional_output

        if text_features is None or text_mask is None:
            raise AttributeError("text_features and text_mask necessary for text conditioning")

        # Encode text and mean-pool over dim 1 (token dimension — assumes the
        # encoder returns (B, seq_len, text_dim); confirm in TextEncoder).
        global_full_text = self.text_encoder(text_features, text_mask)
        global_text = global_full_text.mean(dim=1)
        text_features_encoded = self.text_path(global_text)

        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
        conditional_output = self.conditional_classifier(combined)

        return unconditional_output, conditional_output
|
pikapikagen/evaluate_kid.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
from model import Generator as PikaPikaGen
|
| 4 |
+
from data_loader import create_training_setup
|
| 5 |
+
from utils import denormalize_image
|
| 6 |
+
from torch_fidelity import calculate_metrics
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
# Trained generator weights (path relative to the repository root).
CHECKPOINT_PATH = "pikapikagen/model_checkpoint/checkpoint_epoch_150.pth"

# Must match the tokenizer used when the checkpoint was trained.
TOKENIZER_NAME = "prajjwal1/bert-mini"

BATCH_SIZE = 16  # Batch size for generating images
NUM_WORKERS = 2  # Number of workers for the data loader

# KID is estimated over KID_NUM_SUBSETS random subsets of
# KID_SUBSET_SIZE images each (mean ± std is reported).
KID_SUBSET_SIZE = 50
KID_NUM_SUBSETS = 20

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PokemonKIDEvaluator:
    """Evaluator class for computing KID metrics on PikaPikaGen.

    Generates images for the whole test loader, dumps real and generated
    images to temporary directories, and delegates the KID computation to
    torch-fidelity.
    """

    def __init__(self, checkpoint_path, device=DEVICE):
        """
        Args:
            checkpoint_path: path to a checkpoint containing 'generator_state_dict'.
            device: torch device used for generation.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
        self.checkpoint_path = checkpoint_path

        self._load_model()  # As in gradio demo

    def _load_model(self):
        """Instantiate the generator, restore checkpoint weights, set eval mode."""
        self.generator = PikaPikaGen().to(self.device)

        checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.generator.eval()

    @staticmethod
    def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
        """Map a normalized (3, H, W) tensor to an RGB PIL image."""
        denormalized = denormalize_image(tensor)
        uint8_tensor = (denormalized * 255).clamp(0, 255).to(torch.uint8)
        img_np = uint8_tensor.cpu().permute(1, 2, 0).numpy()  # HWC layout for PIL
        return Image.fromarray(img_np)

    def _save_images_to_temp_dir(self, images_tensor: torch.Tensor, prefix: str) -> str:
        """Save a batch of image tensors to a new temporary directory.

        Returns:
            Path of the created directory (caller is responsible for cleanup).
        """
        temp_dir = tempfile.mkdtemp(prefix=f"pikakid_{prefix}_")
        for i, img_tensor in enumerate(images_tensor):
            pil_img = self._tensor_to_pil(img_tensor)
            img_path = os.path.join(temp_dir, f"{i:06d}.png")
            pil_img.save(img_path)
        return temp_dir

    def evaluate_kid(self, test_loader, resolution="256x256"):
        """Compute KID between real and generated images at the given resolution.

        Args:
            test_loader: DataLoader yielding dicts with "text", "attention_mask",
                and 256x256 "image" tensors.
            resolution: "256x256" or "64x64" (real images are downscaled for 64x64).

        Returns:
            (kid_mean, kid_std) from torch-fidelity.

        Raises:
            ValueError: for an unsupported resolution string.
        """
        all_real_images = []
        all_generated_images = []

        with torch.no_grad():
            for batch in test_loader:
                text_ids = batch["text"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                real_images_256 = batch["image"]  # (B, 3, 256, 256)

                generated_256, generated_64 = self.generator(text_ids, attention_mask)

                # Select the correct resolution for both real and generated images
                if resolution == "256x256":
                    generated_images = generated_256
                    processed_real_images = real_images_256
                elif resolution == "64x64":
                    generated_images = generated_64
                    processed_real_images = torch.nn.functional.interpolate(
                        real_images_256, size=(64, 64), mode='bilinear', align_corners=False
                    )
                else:
                    raise ValueError(f"Unsupported resolution: {resolution}")

                all_real_images.append(processed_real_images.cpu())
                all_generated_images.append(generated_images.cpu())

        # Combine all batches into single tensors
        all_real_images = torch.cat(all_real_images, dim=0)
        all_generated_images = torch.cat(all_generated_images, dim=0)

        # Save images to temporary directories for torch-fidelity.
        # Fix: previously the directories leaked when calculate_metrics raised;
        # clean up in a finally block instead of only on the success path.
        real_temp_dir = None
        generated_temp_dir = None
        try:
            real_temp_dir = self._save_images_to_temp_dir(all_real_images, "real")
            generated_temp_dir = self._save_images_to_temp_dir(all_generated_images, "generated")

            metrics = calculate_metrics(
                input1=generated_temp_dir,  # Path to generated (fake) images
                input2=real_temp_dir,       # Path to real images
                kid=True,
                kid_subset_size=KID_SUBSET_SIZE,
                kid_subsets=KID_NUM_SUBSETS,
                batch_size=BATCH_SIZE,
                device=self.device
            )

            kid_mean = metrics['kernel_inception_distance_mean']
            kid_std = metrics['kernel_inception_distance_std']
        finally:
            # Clean up the temporary directories even on failure.
            for temp_dir in (real_temp_dir, generated_temp_dir):
                if temp_dir is not None:
                    shutil.rmtree(temp_dir, ignore_errors=True)

        return kid_mean, kid_std
|
| 112 |
+
|
| 113 |
+
def main():
    """Evaluate the checkpointed generator's KID on the held-out test split."""
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    training_setup = create_training_setup(
        tokenizer=tokenizer,
        test_set_size=0.2,
        val_set_size=0.1,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        random_seed=42,  # Use a fixed seed for a reproducible split
    )
    test_loader = training_setup['test_loader']
    test_set_size = len(test_loader.dataset)

    evaluator = PokemonKIDEvaluator(checkpoint_path=CHECKPOINT_PATH)

    # Evaluate at both generator output resolutions.
    resolutions_to_test = ['64x64', '256x256']

    print(f"Checkpoint: {CHECKPOINT_PATH}")
    print(f"Test samples: {test_set_size}")
    print(f"KID Subset Size: {KID_SUBSET_SIZE}")
    print(f"KID Subsets: {KID_NUM_SUBSETS}")

    for res in resolutions_to_test:
        kid_mean, kid_std = evaluator.evaluate_kid(test_loader, resolution=res)
        print(f"Resolution {res}:\t KID = {kid_mean:.6f} ± {kid_std:.6f}")


if __name__ == "__main__":
    main()
|
pikapikagen/gradio_demo.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import gradio.themes
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from model import Generator as PikaPikaGen
|
| 8 |
+
from utils import denormalize_image
|
| 9 |
+
from plots import plot_attention_visualization
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The checkpoint committed with this repo lives at
# pikapikagen/model_checkpoint/checkpoint_epoch_150.pth; the previous value
# pointed at a non-existent "model_checkpoints" (plural) directory, making
# model loading fail at startup. Path is relative to pikapikagen/, the
# directory from which the Space README launches gradio_demo.py —
# TODO confirm the runtime working directory.
CHECKPOINT_PATH = "model_checkpoint/checkpoint_epoch_150.pth"
TOKENIZER_NAME = "prajjwal1/bert-mini"
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PokemonGenerator:
    """Main class for the Pokemon generation demo.

    Loads the trained generator once at construction and exposes
    generate_pokemon() for the Gradio interface.
    """

    def __init__(self):
        self.device = DEVICE
        self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

        self._load_model()

    def _load_model(self):
        """Load the trained PikaPikaGen model"""
        try:
            # Initialize model
            self.generator = PikaPikaGen().to(self.device)

            # Load checkpoint
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=self.device, weights_only=True)

            # Load saved weights into model
            self.generator.load_state_dict(checkpoint['generator_state_dict'])
            print(f"✅ Generator loaded from checkpoint (epoch {checkpoint.get('epoch', 'unknown')})")

            # No training
            self.generator.eval()

        except Exception as e:
            # Log the failure, then re-raise so startup fails loudly.
            print(f"❌ Error loading model: {e}")
            raise

    def _tensor_to_pil(self, tensor):
        """Convert tensor to PIL Image"""
        # tensor shape: (3, H, W); values are clamped to [0, 1] before scaling,
        # so the input is expected to be already denormalized.
        img_np = tensor.permute(1, 2, 0).clamp(0, 1).numpy()
        img_np = (img_np * 255).astype(np.uint8)
        return Image.fromarray(img_np)

    def generate_pokemon(self, description, num_samples=4, show_attention=False, resolution="both"):
        """
        Generate Pokemon sprites from text description

        Args:
            description (str): Text description of the desired Pokemon
            num_samples (int): Number of samples to generate (1-8)
            show_attention (bool): Whether to show attention visualization
            resolution (str): Output resolution - "256x256", "64x64", or "both"

        Returns:
            tuple: (generated_images, attention_plot)
        """
        if not description.strip():
            # Errors are reported through the second slot of the returned pair.
            return [], "❌ Please enter a description."

        # No reason to compute gradients
        with torch.no_grad():
            tokens = self.tokenizer(
                description,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Duplicate the single prompt so one forward pass produces
            # num_samples candidate images.
            text_ids = tokens['input_ids'].repeat(num_samples, 1).to(self.device)
            attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(self.device)

            generated_256, generated_64, attention_maps, initial_weights = self.generator(
                text_ids, attention_mask, return_attentions=True
            )

            # Convert tensors to PIL images
            output_images = []
            images_to_process = []
            if resolution in ["256x256", "both"]:
                images_to_process.append(generated_256)
            if resolution in ["64x64", "both"]:
                images_to_process.append(generated_64)

            for img_batch in images_to_process:
                img_batch_denorm = denormalize_image(img_batch.cpu())
                for i in range(num_samples):
                    img_pil = self._tensor_to_pil(img_batch_denorm[i])
                    output_images.append(img_pil)

            attention_plot = None
            if show_attention:
                # Create directory if it doesn't exist
                output_dir = "attention_visualizations"
                os.makedirs(output_dir, exist_ok=True)

                # Create a more descriptive ID for the file
                # To avoid overwriting the same file with the same name
                plot_id = description.strip().replace(" ", "_")[:30]

                # Use the first sample for the attention visualization
                attention_plot = plot_attention_visualization(
                    epoch=0,
                    set_name="demo",
                    output_dir=output_dir,

                    generated_images=generated_256,

                    # Full batch data from the model
                    decoder_attention_maps=attention_maps,
                    initial_context_weights=initial_weights,

                    token_ids=text_ids,
                    attention_mask=attention_mask,
                    tokenizer=self.tokenizer,

                    # Metadata for the specific sample
                    description=description,
                    pokemon_id=plot_id,

                    sample_idx=0,
                    show_inline=False
                )

        return output_images, attention_plot
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# NOTE(review): this runs at import time — importing the module downloads the
# tokenizer and loads the checkpoint before the UI is built.
print("Initializing PikaPikaGen Demo...")
pokemon_gen = PokemonGenerator()
|
| 141 |
+
|
| 142 |
+
def generate_pokemon_interface(description, num_samples, show_attention, resolution):
    """Gradio callback: generate Pokemon sprites from a text description.

    Args:
        description: Free-text prompt describing the Pokemon.
        num_samples: Number of sprites to generate per resolution.
        show_attention: Whether to also produce the attention visualization.
        resolution: One of "256x256", "64x64" or "both".

    Returns:
        Tuple of (list of PIL images for the gallery, attention plot — or the
        error message when generation failed).
    """
    images, attention_plot = pokemon_gen.generate_pokemon(
        description=description,
        num_samples=num_samples,
        show_attention=show_attention,
        resolution=resolution,
    )

    # On failure generate_pokemon() returns (None, <error message>); surface
    # the message through the attention output slot and clear the gallery.
    if images is None:
        return [], attention_plot

    # NOTE: the original built a `status_msg` string here that was never
    # returned or displayed (the interface only wires two outputs); the dead
    # computation has been removed.
    return images, attention_plot
|
| 160 |
+
|
| 161 |
+
def create_interface():
    """Build and return the Gradio Blocks UI for PikaPikaGen.

    Returns:
        The assembled ``gr.Blocks`` demo, ready to be launched.
    """
    with gr.Blocks(
        title="PikaPikaGen: AI Pokemon Generator",
        # Bug fix: was `gradio.themes.Soft()`, but the module is imported
        # under the `gr` alias (see gr.Blocks right above), so the bare name
        # `gradio` raised a NameError when building the interface.
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 0.5em;
        }
        .description {
            text-align: center;
            font-size: 1.1em;
            color: #666;
            margin-bottom: 1em;
        }
        """
    ) as demo:

        # Header banner (styled by the CSS classes declared above).
        gr.HTML("""
        <div class="main-header">🎮 PikaPikaGen: AI Pokemon Generator</div>
        <div class="description">
        Creation of Pokemon sprites from text descriptions using Transformer attention and CNN generation.
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📝 Input")

                description_input = gr.Textbox(
                    label="Pokemon Description",
                    placeholder="Describe your Pokemon! e.g., 'A fire dragon with golden scales and ruby eyes'",
                    lines=3,
                    value="A legendary fire dragon pokemon with golden scales and red eyes"
                )

                with gr.Row():
                    num_samples = gr.Slider(
                        minimum=1, maximum=8, value=4, step=1,
                        label="Number of samples"
                    )

                    resolution = gr.Radio(
                        choices=["256x256", "64x64", "both"],
                        value="256x256",
                        label="Output resolution"
                    )

                show_attention = gr.Checkbox(
                    label="Show attention visualization",
                    value=True,
                    info="Visualize which words the model focuses on"
                )

                generate_btn = gr.Button(
                    "🎨 Generate Pokemon!",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=2):
                gr.Markdown("### 🎨 Generated Pokemon")

                output_gallery = gr.Gallery(
                    label="Generated Pokemon sprites",
                    show_label=True,
                    elem_id="gallery",
                    columns=2,
                    rows=2,
                    height="auto",
                    allow_preview=True
                )

                attention_output = gr.Image(
                    label="Attention visualization",
                    show_label=True,
                    interactive=False
                )

        # Examples section — clicking an example fills the inputs and runs
        # the generation callback (cache_examples=False: always recompute).
        gr.Markdown("### 🌟 Examples to try")
        gr.Examples(
            examples=[
                ["A fire dragon with golden scales and red eyes", 4, True, "256x256"],
                ["An electric mouse with yellow fur and lightning bolts", 3, False, "both"],
                ["A water turtle with blue shell and powerful jaws", 2, True, "256x256"],
                ["A psychic cat with purple fur and mystical powers", 4, True, "256x256"],
                ["A grass serpent with emerald scales and vine whips", 3, False, "64x64"],
                ["An ice phoenix with crystal wings and frozen flames", 4, True, "256x256"],
                ["A dark wolf with shadow abilities and glowing eyes", 2, True, "both"],
                ["A steel robot pokemon with metallic armor and laser beams", 3, False, "256x256"]
            ],
            inputs=[description_input, num_samples, show_attention, resolution],
            outputs=[output_gallery, attention_output],
            fn=generate_pokemon_interface,
            cache_examples=False
        )

        # Wire the button to the generation callback.
        generate_btn.click(
            fn=generate_pokemon_interface,
            inputs=[description_input, num_samples, show_attention, resolution],
            outputs=[output_gallery, attention_output]
        )

        # Footer
        gr.Markdown("""
        ---
        **PikaPikaGen** - Text-to-Image Pokemon Generation using Transformer + CNN
        """)

    return demo
|
| 278 |
+
|
| 279 |
+
# Script entry point: build the Gradio UI and serve it (launch() blocks).
if __name__ == "__main__":
    print("Starting PikaPikaGen Demo...")

    # Create and launch interface
    demo = create_interface()

    demo.launch(
        server_name="0.0.0.0",  # Bind all interfaces to allow external access
        share=False,  # Set to True to create a public gradio.live link
        debug=False,
        show_error=True,  # Surface Python tracebacks in the browser UI
        inbrowser=True  # Auto-open browser
    )
|
pikapikagen/losses.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torchvision import models
|
| 5 |
+
from torchvision.models import VGG19_Weights
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class VGGPerceptualLoss(nn.Module):
    """Perceptual (feature-matching) loss based on a frozen VGG19.

    Features are taken after relu1_2, relu2_2, relu3_2 and relu4_2 of a
    VGG19 pretrained on ImageNet and compared with an L1 distance.
    Inputs are expected in [-1, 1]; they are first shifted to [0, 1] and
    then normalized with the ImageNet mean/std.
    """

    # (name, start, stop) index ranges into vgg19(...).features, chosen so
    # each slice ends right after the desired ReLU activation.
    _SLICE_POINTS = (
        ("relu1_2", 0, 4),    # conv1_1, relu1_1, conv1_2, relu1_2
        ("relu2_2", 4, 9),    # pool1, conv2_1, relu2_1, conv2_2, relu2_2
        ("relu3_2", 9, 18),   # pool2, conv3_x ... relu3_2
        ("relu4_2", 18, 27),  # pool3, conv4_1, relu4_1, conv4_2, relu4_2
    )

    def __init__(self, device):
        super().__init__()
        backbone = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        layers = list(backbone.children())
        self.slices = nn.ModuleDict({
            name: nn.Sequential(*layers[start:stop])
            for name, start, stop in self._SLICE_POINTS
        })
        # The backbone is a fixed feature extractor: freeze every parameter.
        for p in self.parameters():
            p.requires_grad = False

        self.l1 = nn.L1Loss()
        # ImageNet channel statistics, broadcastable over (B, 3, H, W).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, img_gen, img_ref):
        """Sum of L1 distances between VGG feature maps of the two images.

        Args:
            img_gen, img_ref: [B, 3, H, W] tensors in range [-1, 1].

        Returns:
            Scalar loss: sum over the four chosen layers.
        """
        def _prepare(img):
            # [-1, 1] -> [0, 1] -> ImageNet-normalized.
            return ((img + 1.0) / 2.0 - self.mean) / self.std

        feats_gen = _prepare(img_gen)
        feats_ref = _prepare(img_ref)
        total = 0.0
        # Slices are chained: each stage consumes the previous stage's output,
        # so together they traverse the network exactly once per image.
        for stage in self.slices.values():
            feats_gen = stage(feats_gen)
            feats_ref = stage(feats_ref)
            total = total + self.l1(feats_gen, feats_ref)
        return total
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class SobelLoss(nn.Module):
    """Edge-similarity loss: L1 distance between Sobel gradient magnitudes.

    Both inputs are RGB images in [-1, 1]; the edge maps are computed on
    their grayscale versions.
    """

    def __init__(self):
        super(SobelLoss, self).__init__()
        # 3x3 Sobel kernels, shaped (out=1, in=1, 3, 3) for F.conv2d.
        self.kernel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        self.kernel_y = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        self.l1 = nn.L1Loss()

        # ITU-R BT.601 luma coefficients for RGB -> grayscale conversion.
        self.rgb_to_gray_weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)

    def _get_edges(self, img):
        """Return the Sobel gradient-magnitude map of an RGB image.

        Args:
            img: [B, 3, H, W] tensor in range [-1, 1].

        Returns:
            [B, 1, H, W] gradient magnitude map.
        """
        # Map [-1, 1] -> [0, 1], then collapse RGB into one luma channel
        # via a 1x1 convolution with the BT.601 weights.
        luma = F.conv2d((img + 1.0) / 2.0, self.rgb_to_gray_weights.to(img.device))

        # Horizontal and vertical gradients (kernels moved to the input's
        # device on the fly, since they are plain attributes, not buffers).
        gx = F.conv2d(luma, self.kernel_x.to(img.device), padding=1)
        gy = F.conv2d(luma, self.kernel_y.to(img.device), padding=1)

        # Magnitude; the epsilon keeps sqrt differentiable at zero gradient.
        return torch.sqrt(gx**2 + gy**2 + 1e-6)

    def forward(self, img_gen, img_ref):
        """L1 loss between the edge maps of two [B, 3, H, W] images in [-1, 1]."""
        return self.l1(self._get_edges(img_gen), self._get_edges(img_ref))
|
pikapikagen/model.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model_blocks.text_encoder import TextEncoder
|
| 4 |
+
from model_blocks.image_decoder import ImageDecoder
|
| 5 |
+
|
| 6 |
+
class Generator(nn.Module):
    """Full text-to-image model: text encoder followed by the CNN image decoder."""

    def __init__(self, text_encoder_model_name="prajjwal1/bert-mini", noise_dim=100):
        super().__init__()
        self.text_encoder = TextEncoder(
            model_name=text_encoder_model_name,
        )

        # Must match the hidden size produced by the text encoder (bert-mini).
        text_embed_dim = 256

        self.image_decoder = ImageDecoder(
            noise_dim=noise_dim,
            text_embed_dim=text_embed_dim
        )

        self.noise_dim = noise_dim

    def forward(self, token_ids, attention_mask, return_attentions=False):
        """Generate 256x256 and 64x64 images from tokenized text.

        Args:
            token_ids: (batch_size, seq_len) token ids.
            attention_mask: (batch_size, seq_len); 1 = real token, 0 = padding.
            return_attentions: when True, also return the per-step decoder
                attention maps and the initial context weights.
        """
        # Fresh Gaussian noise for every call: (batch_size, noise_dim).
        z = torch.randn(token_ids.size(0), self.noise_dim, device=token_ids.device)

        # Per-token text features: (batch_size, seq_len, text_embed_dim).
        text_features = self.text_encoder(token_ids, attention_mask=attention_mask)

        # The decoder internally computes both the initial context
        # (attention #1) and the per-step cross-attention (attention #2).
        # img_256: (batch_size, 3, 256, 256); img_64: (batch_size, 3, 64, 64).
        img_256, img_64, attn_maps, init_weights = self.image_decoder(
            z, text_features, attention_mask
        )

        if return_attentions:
            return img_256, img_64, attn_maps, init_weights
        return img_256, img_64
|
pikapikagen/model_blocks/decoder_block.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model_blocks.image_cross_attention import ImageCrossAttention
|
| 4 |
+
|
| 5 |
+
class DecoderBlock(nn.Module):
    """One upsampling stage of the image decoder.

    Pipeline: optional (channel adaptation -> text cross-attention ->
    1x1 fusion -> residual add), then ConvTranspose upsampling with
    GroupNorm and LeakyReLU.
    """

    def __init__(self, in_channels, out_channels, use_attention=True, text_embed_dim=256, nhead=4):
        super().__init__()
        self.use_attention = use_attention

        if self.use_attention:
            # 1x1 conv that widens/narrows the feature channels to the text
            # embedding size; identity (None) when they already match.
            self.channel_adapter = (
                nn.Conv2d(in_channels, text_embed_dim, kernel_size=1, bias=False)
                if in_channels != text_embed_dim
                else None
            )

            self.cross_attention = ImageCrossAttention(embed_dim=text_embed_dim, num_heads=nhead)
            # Merges [adapted features ; attention output] — that is
            # 2 * text_embed_dim channels — back down to in_channels.
            self.fusion_conv = nn.Conv2d(text_embed_dim * 2, in_channels, kernel_size=1, bias=False)

        # 2x spatial upsampling followed by normalization and activation.
        self.upsample_block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x, encoder_output=None, attention_mask=None):
        """Apply optional text cross-attention, then upsample.

        Returns:
            (upsampled features, attention weights or None).
        """
        attn_weights = None
        if self.use_attention:
            if encoder_output is None or attention_mask is None:
                raise ValueError("encoder_output and attention_mask must be provided for attention.")

            # Adapt the channel count if the adapter exists.
            adapted = x if self.channel_adapter is None else self.channel_adapter(x)

            attn_output, attn_weights = self.cross_attention(
                image_features=adapted,
                text_features=encoder_output,
                key_padding_mask=attention_mask,
            )

            # Concatenate features with the attended text, project back with
            # a 1x1 conv, and add to the block input as a residual.
            merged = torch.cat([adapted, attn_output], dim=1)  # (B, 2*text_embed_dim, H, W)
            x = x + self.fusion_conv(merged)                   # (B, in_channels, H, W)

        return self.upsample_block(x), attn_weights
|
pikapikagen/model_blocks/image_cross_attention.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class ImageCrossAttention(nn.Module):
    """Cross-attention from image positions (queries) to text tokens (keys/values).

    Handles the (B, C, H, W) <-> (B, H*W, C) reshaping internally and
    converts the HuggingFace-style attention mask into the padding mask
    convention expected by nn.MultiheadAttention.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, image_features, text_features, key_padding_mask=None):
        """Attend from each spatial location to the text sequence.

        Args:
            image_features: (B, C, H, W) spatial query features.
            text_features: (B, seq_len, embed_dim) text encoder output.
            key_padding_mask: (B, seq_len) HuggingFace mask (1 = token, 0 = pad).

        Returns:
            Tuple of (attended features (B, C, H, W),
                      attention weights (B, H*W, seq_len)).
        """
        batch, channels, height, width = image_features.shape

        # Flatten the spatial grid into a query sequence and normalize:
        # (B, C, H, W) -> (B, H*W, C).
        queries = self.layer_norm(image_features.view(batch, channels, height * width).permute(0, 2, 1))

        # HuggingFace marks real tokens with 1; MultiheadAttention expects
        # True at positions to IGNORE, so invert the convention.
        pad_mask = None if key_padding_mask is None else key_padding_mask == 0

        attn_output, attn_weights = self.attention(
            query=queries,
            key=text_features,
            value=text_features,
            key_padding_mask=pad_mask,
            need_weights=True
        )

        # Restore the spatial layout: (B, H*W, C) -> (B, C, H, W).
        spatial = attn_output.permute(0, 2, 1).view(batch, channels, height, width)
        return spatial, attn_weights
|
pikapikagen/model_blocks/image_decoder.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model_blocks.decoder_block import DecoderBlock
|
| 4 |
+
|
| 5 |
+
class ImageDecoder(nn.Module):
    """
    CNN decoder (generator) that synthesizes the image.
    This version applies per-step cross-attention from the very first block.
    Produces both a 256x256 and a 64x64 image from the same shared trunk.
    """
    def __init__(self, noise_dim, text_embed_dim, final_image_channels=3):
        super().__init__()

        # Mechanism to calculate attention scores for the initial context.
        self.initial_context_scorer = nn.Sequential(
            nn.Linear(in_features=text_embed_dim, out_features=512),
            nn.Tanh(),
            nn.Linear(in_features=512, out_features=1)
            # Softmax applied in forward pass so the attention mask can be used
        )

        # Initial linear projection from [noise ; context] to a 4x4 feature map.
        self.initial_projection = nn.Sequential(
            nn.Linear(noise_dim + text_embed_dim, 256 * 4 * 4),
            nn.GroupNorm(1, 256 * 4 * 4),
            nn.LeakyReLU(inplace=True)
        )

        # Shared blocks for both resolutions (up to 64x64)
        self.blocks_64 = nn.ModuleList([
            # Input: (B, 256, 4, 4) -> Output: (B, 256, 8, 8)
            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
            # Input: (B, 256, 8, 8) -> Output: (B, 256, 16, 16)
            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
            # Input: (B, 256, 16, 16) -> Output: (B, 128, 32, 32)
            DecoderBlock(in_channels=256, out_channels=128, use_attention=True),
            # Input: (B, 128, 32, 32) -> Output: (B, 64, 64, 64)
            DecoderBlock(in_channels=128, out_channels=64, use_attention=False),
        ])

        # ModuleList is used instead of a Sequential because the forward pass
        # branches on each block's use_attention flag

        # Blocks only for 256x256 (from 64x64 to 256x256)
        self.blocks_256 = nn.ModuleList([
            # Input: (B, 64, 64, 64) -> Output: (B, 32, 128, 128)
            DecoderBlock(in_channels=64, out_channels=32, use_attention=True),
            # Input: (B, 32, 128, 128) -> Output: (B, 16, 256, 256)
            DecoderBlock(in_channels=32, out_channels=16, use_attention=False),
        ])

        # Last layer to get to RGB channels - 256x256
        # Input: (B, 16, 256, 256) -> Output: (B, 3, 256, 256)
        self.final_conv_256 = nn.Conv2d(16, final_image_channels, kernel_size=3, padding=1)
        self.final_activation_256 = nn.Tanh()

        # Last layer to get to RGB channels - 64x64
        # Input: (B, 64, 64, 64) -> Output: (B, 3, 64, 64)
        self.final_conv_64 = nn.Conv2d(64, final_image_channels, kernel_size=3, padding=1)
        self.final_activation_64 = nn.Tanh()

    def forward(self, noise, encoder_output_full, attention_mask):
        # noise.shape: (B, noise_dim)
        # encoder_output_full.shape: (B, seq_len, text_embed_dim)
        # attention_mask.shape: (B, seq_len); 1 = real token, 0 = padding

        # 1. Compute the first attention: one score (logit) per token
        attn_scores = self.initial_context_scorer(encoder_output_full)

        # Apply the attention mask before the softmax: padding positions
        # (mask == 0) get a large negative score so they receive ~0 weight.
        # Mask is (B, seq_len), scores are (B, seq_len, 1); the unsqueeze
        # aligns the trailing dimension for broadcasting.
        # NOTE: masked_fill_ mutates attn_scores in place.
        attn_scores.masked_fill_(attention_mask.unsqueeze(-1) == 0, -1e9)

        # attention_weights.shape: (B, seq_len, 1)
        attention_weights = torch.softmax(attn_scores, dim=1)

        # Weighted average of the encoder output
        # context_vector.shape: (B, text_embed_dim)
        context_vector = torch.sum(attention_weights * encoder_output_full, dim=1)

        # 2. Merge the noise and the context vector for the initial projection
        # initial_input.shape: (B, noise_dim + text_embed_dim)
        initial_input = torch.cat([noise, context_vector], dim=1)

        # 3. Initial projection and reshape to fit the transposed convolutions
        # x.shape: (B, 256 * 4 * 4)
        x = self.initial_projection(initial_input)
        # x.shape: (B, 256, 4, 4)
        x = x.view(x.size(0), 256, 4, 4)

        # 4. Pass through the decoder blocks, collecting per-step attention maps
        attention_maps = []

        # Shared path for both resolutions (up to 64x64)
        for block in self.blocks_64:
            encoder_ctx = encoder_output_full if block.use_attention else None
            mask_ctx = attention_mask if block.use_attention else None
            x, attn_weights = block(x, encoder_ctx, mask_ctx)
            if attn_weights is not None:
                attention_maps.append(attn_weights)

        # Now x has size (B, 64, 64, 64)

        # 64x64-only head
        image_64 = self.final_conv_64(x)
        image_64 = self.final_activation_64(image_64)

        # 5. 256x256-only path (continues from the shared 64x64 features)
        for block in self.blocks_256:
            encoder_ctx = encoder_output_full if block.use_attention else None
            mask_ctx = attention_mask if block.use_attention else None
            x, attn_weights = block(x, encoder_ctx, mask_ctx)
            if attn_weights is not None:
                attention_maps.append(attn_weights)

        # Final layer for 256x256
        # x.shape: (B, 16, 256, 256) -> (B, 3, 256, 256)
        image_256 = self.final_conv_256(x)
        image_256 = self.final_activation_256(image_256)

        # attention_weights here is the INITIAL context weights computed in
        # step 1 (the loop variable is attn_weights, which does not shadow it).
        return image_256, image_64, attention_maps, attention_weights
|
pikapikagen/model_blocks/text_encoder.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from transformers import AutoModel
|
| 3 |
+
|
| 4 |
+
class TextEncoder(nn.Module):
    """Text encoder.

    Reuses the token embeddings of a pretrained bert-mini and refines
    them with a 4-layer Transformer encoder (d_model=256).
    """

    def __init__(self, model_name="prajjwal1/bert-mini", fine_tune_embeddings=True):
        super().__init__()
        # Only the embedding layer of the pretrained model is kept;
        # the rest of bert-mini is discarded.
        pretrained = AutoModel.from_pretrained(model_name)
        self.embedding = pretrained.embeddings

        # Enable or freeze gradient flow through the embeddings.
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune_embeddings

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256, nhead=4, dim_feedforward=1024, batch_first=True
            ),
            num_layers=4,
        )

    def forward(self, token_ids, attention_mask=None):
        """Encode token ids into contextual features.

        Args:
            token_ids: (batch_size, seq_len) token ids.
            attention_mask: (batch_size, seq_len) HuggingFace mask
                (1 = real token, 0 = padding), or None.

        Returns:
            (batch_size, seq_len, 256) contextual token embeddings.
        """
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embedded = self.embedding(token_ids)

        # HuggingFace marks real tokens with 1; nn.TransformerEncoder
        # expects True at positions to ignore, so invert the convention.
        pad_mask = (attention_mask == 0) if attention_mask is not None else None

        return self.transformer_encoder(
            src=embedded,
            src_key_padding_mask=pad_mask
        )
|
pikapikagen/model_checkpoint/checkpoint_epoch_150.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7902ac75581c4a54ec5345ccf2bd30440a99a4c1031b5c12af6cabb318dde225
|
| 3 |
+
size 789795998
|
pikapikagen/plots.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
import io
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from utils import denormalize_image
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def save_attention_visualization(
    epoch, model, tokenizer, batch, device, set_name, output_dir, show_inline=False
):
    """Generate and save the attention visualization for one sample of `batch`.

    Runs the model on the batch's first sample; if attention data is
    available, renders and saves the plot, otherwise logs a skip message.
    """
    print(f"Epoch {epoch}: Generating attention visualization for {set_name} set...")

    data = generate_attention_data(model, tokenizer, batch, device)
    if not data:
        # generate_attention_data returns None when maps/tokens are missing.
        print(f"Epoch {epoch}: Skipped attention visualization due to missing data.")
        return

    plot_attention_visualization(
        epoch=epoch,
        set_name=set_name,
        output_dir=output_dir,
        show_inline=show_inline,
        **data,
    )
    print(f"Epoch {epoch}: Attention visualization saved for Pokémon #{data['pokemon_id']}.")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def generate_attention_data(model, tokenizer, batch, device):
    """
    Run the model on the first sample of `batch` and collect the generated
    image together with its attention maps, dropping padding/special tokens.

    Returns:
        dict | None: CPU tensors and metadata for visualization, or None when
        the model exposes no attention maps or no displayable tokens remain.
    """
    model.eval()

    with torch.no_grad():
        token_ids = batch["text"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Visualize a single sample: slice the batch down to its first element
        # while keeping the leading batch dimension.
        if token_ids.dim() > 1:
            token_ids = token_ids[0].unsqueeze(0)
            attention_mask = attention_mask[0].unsqueeze(0)

        # Metadata for the selected sample.
        pokemon_id = batch["idx"][0]
        description = batch["description"][0]

        generated_image, attention_maps, initial_context_weights = model(
            token_ids, attention_mask, return_attentions=True
        )

        decoder_attention_maps = [m for m in attention_maps if m is not None]

        if not decoder_attention_maps or initial_context_weights is None:
            print("Attention maps not available. Skipping data generation.")
            return None

        # Keep only real tokens: skip [SEP]/[PAD] and anything masked out.
        tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0))
        skipped = [tokenizer.sep_token, tokenizer.pad_token]
        display_tokens = [
            {"token": tok, "index": pos}
            for pos, tok in enumerate(tokens_all)
            if tok not in skipped and attention_mask[0, pos] == 1
        ]

        if not display_tokens:
            print(f"No valid tokens to display for '{description}'. Skipping.")
            return None

        return {
            "generated_image": generated_image.cpu(),
            "decoder_attention_maps": [m.cpu() for m in decoder_attention_maps],
            "initial_context_weights": initial_context_weights.cpu(),
            "display_tokens": display_tokens,
            "description": description,
            "pokemon_id": pokemon_id,
        }
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def plot_attention_visualization(
    # Plot identification arguments
    epoch: int,
    set_name: str,
    output_dir: str | None,
    # Data generated by the model (can be full batches)
    generated_images: torch.Tensor,
    decoder_attention_maps: list[torch.Tensor],
    initial_context_weights: torch.Tensor,
    # Original text input (can be a full batch)
    token_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    tokenizer: AutoTokenizer,
    # Batch metadata (for the specific sample)
    description: str,
    pokemon_id: int | str,
    # Control options
    sample_idx: int = 0,
    show_inline: bool = False,
):
    """
    Generates and saves an attention visualization for a single sample from a batch.

    This function is self-contained: it accepts full batch tensors and internally
    handles sample selection and token preparation.

    Args:
        epoch (int): Epoch number (for title/filename).
        set_name (str): Set name (e.g., 'train', for title/filename).
        output_dir (str, optional): Folder to save the image. If None, the plot is not saved.
        generated_images (torch.Tensor): Tensor of generated images.
            Shape: (B, C, H, W).
        decoder_attention_maps (list[torch.Tensor]): List of attention tensors.
            Each tensor shape: (B, num_patches, seq_length).
        initial_context_weights (torch.Tensor): Initial attention weights.
            Shape: (B, 1, seq_length).
        token_ids (torch.Tensor): Input tokens. Shape: (B, seq_length).
        attention_mask (torch.Tensor): Attention mask for tokens.
            Shape: (B, seq_length).
        tokenizer: The tokenizer object for id -> token conversion.
        description (str): The text prompt for the selected sample.
        pokemon_id (int or str): The ID of the selected sample.
        sample_idx (int, optional): Index of the sample in the batch to visualize.
            Defaults to 0.
        show_inline (bool, optional): If True, shows the plot. Defaults to False.

    Returns:
        PIL.Image.Image: The rendered figure as a PIL image (useful for logging).
    """
    # Select the specific sample using sample_idx and move to CPU
    img_tensor = generated_images[sample_idx].cpu()
    layer_maps = [m[sample_idx].cpu() for m in decoder_attention_maps if m is not None]
    initial_weights = initial_context_weights[sample_idx].cpu()
    token_ids_sample = token_ids[sample_idx].cpu()
    attention_mask_sample = attention_mask[sample_idx].cpu()

    # Token filtering: keep only real, unmasked tokens for display.
    tokens_all = tokenizer.convert_ids_to_tokens(token_ids_sample)
    display_tokens = []
    for i, token in enumerate(tokens_all):
        if (
            token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask_sample[i] == 1
        ):
            display_tokens.append({"token": token, "index": i})

    img_tensor_cpu = denormalize_image(img_tensor).permute(1, 2, 0)
    num_decoder_layers = len(layer_maps)
    num_tokens = len(display_tokens)
    token_indices_to_display = [t["index"] for t in display_tokens]

    # Grid geometry: up to 8 token heatmaps per row, one row group per layer.
    cols = min(num_tokens, 8)
    rows_per_layer = (num_tokens + cols - 1) // cols
    height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
    fig_height = sum(height_ratios)
    fig_width = max(20, 2.5 * cols)

    fig = plt.figure(figsize=(fig_width, fig_height))
    gs_main = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=1.2)
    fig.suptitle(f"Epoch {epoch}: Attention for Pokémon #{pokemon_id} ({set_name.capitalize()})", fontsize=24)

    # Row 0: the generated image with its prompt underneath.
    ax_main_img = fig.add_subplot(gs_main[0])
    ax_main_img.imshow(img_tensor_cpu)
    ax_main_img.set_title("Generated Image", fontsize=18)
    ax_main_img.text(0.5, -0.1, f"Prompt: {description}", ha="center", va="top",
                     transform=ax_main_img.transAxes, fontsize=14, wrap=True)
    ax_main_img.axis("off")

    # Row 1: bar chart of the initial (global) context attention per token.
    ax_initial_attn = fig.add_subplot(gs_main[1])
    initial_weights_squeezed = initial_weights.squeeze().numpy()
    token_strings = [t["token"] for t in display_tokens]
    relevant_weights = initial_weights_squeezed[[t["index"] for t in display_tokens]]
    ax_initial_attn.bar(np.arange(len(token_strings)), relevant_weights, color="skyblue")
    ax_initial_attn.set_xticks(np.arange(len(token_strings)))
    ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
    ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
    ax_initial_attn.set_ylabel("Weight", fontsize=12)
    ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)

    # Iterate through each decoder layer's attention maps
    for i, layer_attn_map in enumerate(layer_maps):
        # layer_attn_map shape is now (num_patches, seq_len);
        # num_patches is assumed to be a perfect square — TODO confirm upstream.
        map_size_flat = layer_attn_map.shape[0]
        map_side = int(np.sqrt(map_size_flat))
        layer_title = f"Decoder Cross-Attention Layer {i+1} (Size: {map_side}x{map_side})"

        # Extract attention weights only for tokens we want to display,
        # and share one color scale (vmin/vmax) across the whole layer.
        relevant_attn_maps = layer_attn_map[:, token_indices_to_display]
        vmin, vmax = relevant_attn_maps.min(), relevant_attn_maps.max()

        # Create subplot grid for this layer (the extra column hosts the colorbar)
        gs_layer = gs_main[2 + i].subgridspec(rows_per_layer, cols + 1, wspace=0.2, hspace=0.4, width_ratios=[*([1] * cols), 0.1])
        axes_in_layer = [fig.add_subplot(gs_layer[r, c]) for r in range(rows_per_layer) for c in range(cols)]

        # Add layer title above the token attention maps
        if axes_in_layer:
            y_pos = axes_in_layer[0].get_position().y1
            fig.text(0.5, y_pos + 0.01, layer_title, ha="center", va="bottom", fontsize=16, weight="bold")

        # Plot attention heatmap for each token
        im = None
        for j, token_info in enumerate(display_tokens):
            if j >= len(axes_in_layer):
                break
            ax = axes_in_layer[j]
            attn_for_token = layer_attn_map[:, token_info["index"]]
            # Reshape flat attention to spatial grid
            heatmap = attn_for_token.reshape(map_side, map_side)
            im = ax.imshow(heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax)
            ax.set_title(f"'{token_info['token']}'", fontsize=12)
            ax.axis("off")

        # Add colorbar for the layer
        if im:
            cax = fig.add_subplot(gs_layer[:, -1])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(labelsize=10)
            cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)

        # Hide unused subplots
        for j in range(num_tokens, len(axes_in_layer)):
            axes_in_layer[j].axis("off")

    plt.tight_layout(rect=(0, 0.03, 1, 0.96))
    if output_dir is not None:
        # BUG FIX: create the output folder like the other save_* helpers do,
        # so savefig does not fail when the directory does not exist yet.
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_attention_visualization_{pokemon_id}.png")
        plt.savefig(save_path, bbox_inches="tight")

    # Save figure to bytes for potential further use (e.g., logging)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    buf.seek(0)

    # Convert to PIL image (keeps a reference to buf, so it stays valid)
    attention_plot = Image.open(buf)

    if show_inline:
        plt.show()
    plt.close(fig)

    return attention_plot
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=True):
    """
    Plot the generator and discriminator loss curves and save them to disk.

    Args:
        losses_g (list[float]): Generator loss per epoch.
        losses_d (list[float]): Discriminator loss per epoch.
        output_dir (str): Folder the PNG is written to (created if missing).
        show_inline (bool): If True, display the figure; otherwise close it.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    for series, label, color in (
        (losses_g, "Generator Loss", "blue"),
        (losses_d, "Discriminator Loss", "red"),
    ):
        ax.plot(series, label=label, color=color)
    ax.set_title("Training Losses")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

    save_path = os.path.join(output_dir, "training_losses.png")
    plt.savefig(save_path)
    print(f"Loss plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)
|
| 272 |
+
|
| 273 |
+
def save_plot_non_gan_losses(train_losses_history, val_losses_history, output_dir="training_output", show_inline=True, filter_losses=None):
    """
    Generates and saves plots of losses for non-GAN models with multiple loss components.

    Args:
        train_losses_history (list[dict]): List of dicts containing training losses per epoch.
            e.g., [{'l1': 0.5, 'sobel': 0.3}, ...]
        val_losses_history (list[dict]): List of dicts containing validation losses per epoch.
        output_dir (str): Directory to save the plot.
        show_inline (bool): Whether to display the plot inline.
        filter_losses (list[str], optional): List of loss names to plot.
            If None, plots all found losses.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Extract all unique loss keys from both training and validation
    all_keys = set()
    for losses_dict in train_losses_history + val_losses_history:
        all_keys.update(losses_dict.keys())

    # Filter out non-numeric keys if any
    loss_keys = [key for key in all_keys if key not in ['epoch']]

    # Apply filter if specified
    if filter_losses is not None:
        loss_keys = [key for key in loss_keys if key in filter_losses]

    loss_keys = sorted(loss_keys)  # Sort for consistent ordering

    # BUG FIX: with no keys (empty histories, or a filter that matches
    # nothing), cols would be 0 and the ceiling division below would raise
    # ZeroDivisionError. Bail out gracefully instead.
    if not loss_keys:
        print("No losses to plot; skipping non-GAN loss plot.")
        return

    # Create subplots
    n_losses = len(loss_keys)
    cols = min(3, n_losses)  # Max 3 columns
    rows = (n_losses + cols - 1) // cols  # Ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows))
    # Normalize `axes` so it is always indexable as a flat sequence.
    if n_losses == 1:
        axes = [axes]
    elif rows > 1:
        axes = axes.flatten()

    fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98)

    for i, loss_key in enumerate(loss_keys):
        ax = axes[i]

        # Extract train and validation losses for this key
        # (missing keys default to 0 so ragged histories still plot)
        train_values = [losses.get(loss_key, 0) for losses in train_losses_history]
        val_values = [losses.get(loss_key, 0) for losses in val_losses_history]

        epochs_train = range(1, len(train_values) + 1)
        epochs_val = range(1, len(val_values) + 1)

        # Plot training and validation curves
        if train_values:
            ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5)
        if val_values:
            ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--')

        ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)

    # Hide unused subplots
    for i in range(n_losses, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()

    # Save the plot
    save_path = os.path.join(output_dir, "non_gan_training_losses.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Non-GAN training losses plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True):
    """
    Generates and saves/shows a horizontal comparison grid (real vs. generated).
    Automatically handles 256x256 or 64x64 output based on set_name.

    Args:
        epoch (int): Current epoch (used in title/filename).
        model: Generator model; may return a tuple of (256px, 64px) outputs.
        batch (dict): Batch with "text", "attention_mask", "image", "idx", "description".
        set_name (str): Split name; if it contains "64" the 64x64 output is used.
        device: Torch device for the forward pass.
        output_dir (str): Folder the PNG is written to (created if missing).
        show_inline (bool): If True, display the figure; otherwise close it.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    token_ids = batch["text"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    real_images = batch["image"]
    pokemon_ids = batch["idx"]
    descriptions = batch["description"]
    num_images = real_images.size(0)

    with torch.no_grad():
        generated_images = model(token_ids, attention_mask)
        # Handle tuple output from generator (e.g., 256px and 64px images)
        if isinstance(generated_images, tuple):
            # Check if we want 64x64 or 256x256 based on set_name
            if "64" in set_name:
                generated_images = generated_images[1]  # Use 64x64 output
                # Resize real images to 64x64 for comparison
                real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)
            else:
                generated_images = generated_images[0]  # Use 256x256 output

    # BUG FIX: squeeze=False keeps axs 2-D even when num_images == 1;
    # without it plt.subplots squeezes the column axis and axs[0, i] raises.
    fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5), squeeze=False)
    resolution = "64x64" if "64" in set_name else "256x256"
    fig.suptitle(
        f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98
    )

    # Top row: ground-truth images; bottom row: generated images.
    for i in range(num_images):
        ax_real = axs[0, i]
        ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0))
        ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10)
        ax_real.axis("off")

        ax_gen = axs[1, i]
        ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0))
        ax_gen.axis("off")

    # Row labels on the left edge of the grid.
    axs[0, 0].text(
        -0.1,
        0.5,
        "Real",
        ha="center",
        va="center",
        rotation="vertical",
        fontsize=14,
        transform=axs[0, 0].transAxes,
    )
    axs[1, 0].text(
        -0.1,
        0.5,
        "Generated",
        ha="center",
        va="center",
        rotation="vertical",
        fontsize=14,
        transform=axs[1, 0].transAxes,
    )

    plt.tight_layout(rect=(0, 0, 1, 0.95))

    # Save the figure and optionally show it
    save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png")
    plt.savefig(save_path)

    if show_inline:
        plt.show()
    else:
        plt.close(fig)
|
pikapikagen/utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def denormalize_image(tensor):
    """
    Denormalize an image tensor from the [-1, 1] range to [0, 1] for visualization.

    Args:
        tensor (torch.Tensor): The image tensor, with values in [-1, 1].

    Returns:
        torch.Tensor: The denormalized tensor with values in [0, 1].
    """
    # Shift and rescale, then clamp to guard against out-of-range values.
    return ((tensor + 1) / 2).clamp(0, 1)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "pikapikagen"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "PikaPikaGen: text-to-image Pokémon sprite generation (training, evaluation, and Gradio demo)"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"gradio>=5.35.0",
|
| 9 |
+
"ipykernel>=6.29.5",
|
| 10 |
+
"ipywidgets>=8.1.7",
|
| 11 |
+
"jupyterlab>=4.4.5",
|
| 12 |
+
"matplotlib>=3.10.3",
|
| 13 |
+
"pandas>=2.3.0",
|
| 14 |
+
"sentence-transformers>=5.0.0",
|
| 15 |
+
"torch-fidelity>=0.3.0",
|
| 16 |
+
"torchvision>=0.22.1",
|
| 17 |
+
"transformers>=4.53.0",
|
| 18 |
+
]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|