Add files using upload-large-folder tool
- OFFICIAL_UVDoc_对标说明.txt +41 -0
- UVDoc_official/.gitignore +136 -0
- UVDoc_official/LICENSE +21 -0
- UVDoc_official/README.md +160 -0
- UVDoc_official/__pycache__/data_UVDoc.cpython-310.pyc +0 -0
- UVDoc_official/__pycache__/data_UVDoc.cpython-313.pyc +0 -0
- UVDoc_official/__pycache__/data_custom_augmentations.cpython-310.pyc +0 -0
- UVDoc_official/__pycache__/data_doc3D.cpython-312.pyc +0 -0
- UVDoc_official/__pycache__/data_utils.cpython-313.pyc +0 -0
- UVDoc_official/compute_uvdoc_grid3d_stats.py +73 -0
- UVDoc_official/data/readme.txt +1 -0
- UVDoc_official/data_UVDoc.py +232 -0
- UVDoc_official/data_custom_augmentations.py +148 -0
- UVDoc_official/data_doc3D.py +95 -0
- UVDoc_official/data_mixDataset.py +23 -0
- UVDoc_official/data_utils.py +175 -0
- UVDoc_official/demo.py +55 -0
- UVDoc_official/docUnet_eval.py +106 -0
- UVDoc_official/docUnet_pred.py +179 -0
- UVDoc_official/model.py +374 -0
- UVDoc_official/requirements_demo.txt +3 -0
- UVDoc_official/requirements_eval.txt +10 -0
- UVDoc_official/requirements_train.txt +5 -0
- UVDoc_official/run_official_overfit_train_infer.sh +101 -0
- UVDoc_official/train.py +552 -0
- UVDoc_official/utils.py +66 -0
- UVDoc_official/uvdocBenchmark_eval.py +129 -0
- UVDoc_official/uvdocBenchmark_metric.py +152 -0
- UVDoc_official/uvdocBenchmark_pred.py +131 -0
- UVDoc_official/verify_ckpt_val_pipeline.py +153 -0
- UVDoc_official/verify_uvdoc_train_infer_preprocess.py +169 -0
- baseline_resnet_unet/__init__.py +5 -0
- baseline_resnet_unet/dataset.py +197 -0
- baseline_resnet_unet/model.py +89 -0
- baseline_resnet_unet/train.py +187 -0
- baseline_resnet_unet/warp.py +24 -0
- log_full_uvdoc_gpu0.bak_20260411_122217/nohup.out +9 -0
- log_full_uvdoc_gpu0.bak_20260411_122217/params8_lr=0.0002_nepochs50_nepochsdecay20_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc.txt +2 -0
- log_full_uvdoc_gpu0/nohup.out +0 -0
- log_full_uvdoc_gpu0/params8_lr=0.0002_nepochs25_nepochsdecay10_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc.txt +111 -0
- log_full_uvdoc_gpu0/verify_val_ep12_infer/metrics.txt +1001 -0
- requirements_baseline.txt +6 -0
- requirements_uvdoc_train.txt +9 -0
- run_overfit_official_uvdoc.sh +27 -0
- run_overfit_train_infer_consistency.sh +75 -0
- run_train_full_uvdoc_gpu0.sh +40 -0
- run_train_official_config.sh +80 -0
- run_train_uvdoc_baseline.py +11 -0
- unzip_extract.log +1 -0
- uvdoc_文档矫正_colab_技术路线(gemini_可执行版).md +212 -0
OFFICIAL_UVDoc_对标说明.txt
ADDED
@@ -0,0 +1,41 @@
[Official UVDoc code (cloned + small patches)]
Path: UvDoc/UVDoc_official/
Upstream: https://github.com/tanguymagne/UVDoc

[Default settings matching the paper/repo]
- Model: model.UVDocnet (sparse grid2D + grid3D; utils.bilinear_unwarping + grid_sample)
- Loss: alpha * L1(grid2D) + beta * L1(grid3D) + gamma * L1(reconstruction); gamma is enabled from ep_gamma_start onward (consistent with train.py)
- Official default data mode: --data_to_use both (mixed Doc3D + UVDoc training); the validation set is the Doc3D val split (same as upstream)

[For a strict reproduction, prepare]
1) The extracted UVDoc_final directory (containing img/, grid2d/, grid3d/, metadata_sample/, etc.)
2) The Doc3D data plus the authors' Doc3D_grid (see the links in the official README)

[UVDoc_final only, no Doc3D (extended mode)]
This directory adds data_to_use=uvdoc to the official train.py:
- train/val are split by sample id within the same UVDoc set (default val_ratio=0.05, configurable)
- validation still uses the reconstruction MSE (same form as the upstream val, but the data domain is UVDoc rather than Doc3D)
Note: this is not identical to the paper's "val on Doc3D"; it is only meant to get the official network and losses running locally first.

[Usage examples]
cd /mnt/zsn/zsn_workspace/dzx/UvDoc/UVDoc_official

# A) Official default (requires Doc3D + UVDoc)
python train.py --data_to_use both \
    --data_path_doc3D /path/to/data/doc3D/ \
    --data_path_UVDoc /mnt/zsn/zsn_workspace/dzx/UvDoc/UVDoc_final \
    --logdir ./log/uvdoc_official

# B) UVDoc only (no Doc3D)
python train.py --data_to_use uvdoc \
    --data_path_UVDoc /mnt/zsn/zsn_workspace/dzx/UvDoc/UVDoc_final \
    --logdir ./log/uvdoc_only

# Dependencies: see requirements_train.txt (pinned to older versions; if your environment already has torch 2.x, it usually runs as-is)

[Evaluation / inference]
demo.py, uvdocBenchmark_pred.py, docUnet_pred.py, etc. in this directory match upstream; the checkpoint key is model_state.

[Difference from baseline_resnet_unet]
- baseline_resnet_unet: ResNet50 + UNet dense UV, the simplified route from the technical notes
- UVDoc_official: the network and supervision matching the SIGGRAPH Asia paper implementation (grid2D + grid3D)
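The loss described in the notes above (alpha * L1(grid2D) + beta * L1(grid3D) + gamma * L1(reconstruction), with the gamma term enabled only from ep_gamma_start) can be sketched as follows. This is a NumPy illustration, not the actual train.py code; the dict layout, array shapes, and the epoch gating are assumptions for illustration only.

```python
import numpy as np

def l1(pred, target):
    # Mean absolute error, analogous to torch.nn.L1Loss with default reduction.
    return float(np.abs(pred - target).mean())

def uvdoc_loss(pred, target, epoch, alpha=5.0, beta=5.0, gamma=1.0, ep_gamma_start=10):
    """Sketch of the combined UVDoc training loss.

    pred/target are dicts with keys 'grid2d', 'grid3d', 'recon' (hypothetical layout);
    the reconstruction term is only added once epoch >= ep_gamma_start.
    """
    loss = alpha * l1(pred["grid2d"], target["grid2d"])
    loss += beta * l1(pred["grid3d"], target["grid3d"])
    if epoch >= ep_gamma_start:
        loss += gamma * l1(pred["recon"], target["recon"])
    return loss

rng = np.random.default_rng(0)
pred = {k: rng.normal(size=(2, 45, 31)) for k in ("grid2d", "grid3d", "recon")}
target = {k: np.zeros((2, 45, 31)) for k in ("grid2d", "grid3d", "recon")}
early = uvdoc_loss(pred, target, epoch=0)   # reconstruction term not yet enabled
late = uvdoc_loss(pred, target, epoch=10)   # reconstruction term enabled
```

The defaults (alpha=5.0, beta=5.0, gamma=1.0, ep_gamma_start=10) mirror the hyperparameters visible in the log filenames below.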
UVDoc_official/.gitignore
ADDED
@@ -0,0 +1,136 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Data

data/doc3D/*
data/DocUNet/*
data/UVDoc/*
data/UVDoc_benchmark/*
UVDoc_official/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Tanguy MAGNE

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
UVDoc_official/README.md
ADDED
@@ -0,0 +1,160 @@
# UVDoc: Neural Grid-based Document Unwarping

This repository contains the code for the "UVDoc: Neural Grid-based Document Unwarping" paper.
If you are looking for (more information about) the UVDoc dataset, you can find it [here](https://github.com/tanguymagne/UVDoc-Dataset).
The full UVDoc paper can be found [here](https://igl.ethz.ch/projects/uvdoc/).

Three requirements files are provided for the three use cases made available in this repo.
Each use case is detailed below.


## Demo
> **Note** : Requirements
>
> Before trying to unwarp a document using our model, you need to install the requirements. To do so, we advise you to create a virtual environment. Then run `pip install -r requirements_demo.txt`.

To try our model (available in this repo at `model/best_model.pkl`) on your custom images, run the following:
```shell
python demo.py --img-path [PATH/TO/IMAGE]
```

You can also use a model you trained yourself by specifying the path to the model like this:
```shell
python demo.py --img-path [PATH/TO/IMAGE] --ckpt-path [PATH/TO/MODEL]
```


## Model training
> **Note** : Requirements
>
> Before training a model, you need to install the requirements. To do so, we advise you to create a virtual environment. Then run `pip install -r requirements_train.txt`.

To train a model, you first need to get the data:
- The UVDoc dataset can be accessed [here](https://igl.ethz.ch/projects/uvdoc/UVDoc_final.zip).
- The Doc3D dataset can be downloaded from [here](https://github.com/cvlab-stonybrook/doc3D-dataset). We augmented this dataset with 2D grids and 3D grids that are available [here](https://igl.ethz.ch/projects/uvdoc/Doc3D_grid.zip).

Then, unzip the downloaded archive into the data folder. The final structure of the data folder should be as follows:
```
data/
├── doc3D
│   ├── grid2D
│   ├── grid3D
│   ├── bm
│   └── img
└── UVDoc
    ├── grid2d
    ├── grid3d
    ├── img
    ├── img_geom
    ├── metadata_geom
    ├── metadata_sample
    ├── seg
    ├── textures
    ├── uvmap
    ├── warped_textures
    └── wc
```

Once this is done, run the following:
```shell
python train.py
```

Several hyperparameters, such as data augmentations, number of epochs, learning rate, or batch size can be tuned. To learn about them, please run the following:
```shell
python train.py --help
```


## Evaluation
> **Note** : Requirements
>
> Before evaluating a model, you need to install the requirements. To do so, we advise you to create a virtual environment. Then run `pip install -r requirements_eval.txt`.
>
> You will also need to install `matlab.engine`, to allow interfacing matlab with python. To do so, you first need to find the location of your matlab installation (for instance, by running `matlabroot` from within matlab). Then go to `<matlabroot>/extern/engines/python` and run `python setup.py install`. You can open a python prompt and run `import matlab.engine` followed by `eng = matlab.engine.start_matlab()` to see if it was successful.
>
> Finally, you might need to install `tesseract` via `sudo apt install tesseract-ocr libtesseract-dev`.

You can easily evaluate our model or a model you trained yourself using the provided script.
Our model is available in this repo at `model/best_model.pkl`.

### DocUNet benchmark
To make predictions using a model on the DocUNet benchmark, please first download the DocUNet Benchmark (available [here](https://www3.cs.stonybrook.edu/~cvl/docunet.html)) and place it under data to have the following structure:
```
data/
└── DocUNet
    ├── crop
    ├── original
    └── scan
```

Then run:
```shell
python docUnet_pred.py --ckpt-path [PATH/TO/MODEL]
```
This will create a `docunet` folder next to the model, containing the unwarped images.

Then to compute the metrics over these predictions, please run the following:
```shell
python docUnet_eval.py --pred-path [PATH/TO/UNWARPED]
```
### UVDoc benchmark
To make predictions using a model on the UVDoc benchmark, please first download the UVDoc Benchmark (available [here](https://igl.ethz.ch/projects/uvdoc/)) and place it under data to have the following structure:
```
data/
└── UVDoc_benchmark
    ├── grid2d
    ├── grid3d
    └── ...
```
Then run:
```shell
python uvdocBenchmark_pred.py --ckpt-path [PATH/TO/MODEL]
```
This will create an `output_uvdoc` folder next to the model, containing the unwarped images.

Then to compute the metrics over these predictions, please run the following:
```shell
python uvdocBenchmark_eval.py --pred-path [PATH/TO/UNWARPED]
```

#### :exclamation: Erratum
The MS-SSIM and AD values for the UVDoc benchmark reported in our paper were mistakenly calculated on only half of the UVDoc benchmark (for our method as well as related works).
We here report the old and the corrected values on the entire UVDoc benchmark:

| :white_check_mark: New :white_check_mark: | MS-SSIM | AD |
|-----------|---------|-------|
| DewarpNet | 0.589 | 0.193 |
| DocTr | 0.697 | 0.160 |
| DDCP | 0.585 | 0.290 |
| RDGR | 0.610 | 0.280 |
| DocGeoNet | 0.706 | 0.168 |
| Ours | 0.785 | 0.119 |

| :x: Old :x: | MS-SSIM | AD |
|-----------|---------|-------|
| DewarpNet | 0.6 | 0.189 |
| DocTr | 0.684 | 0.176 |
| DDCP | 0.591 | 0.334 |
| RDGR | 0.603 | 0.314 |
| DocGeoNet | 0.714 | 0.167 |
| Ours | 0.784 | 0.122 |

## Resulting images
You can download the unwarped images that we used in our paper:
* [Our results for the DocUNet benchmark](https://igl.ethz.ch/projects/uvdoc/DocUnet_results.zip)
* [Our results for the UVDoc benchmark](https://igl.ethz.ch/projects/uvdoc/UVDocBenchmark_results.zip)
* [The results of related work for the UVDoc benchmark](https://igl.ethz.ch/projects/uvdoc/UVDocBenchmark_results_RelatedWorks.zip) (generated using their respective published pretrained models)

## Citation
If you used this code or the UVDoc dataset, please consider citing our work:
```
@inproceedings{UVDoc,
  title={{UVDoc}: Neural Grid-based Document Unwarping},
  author={Floor Verhoeven and Tanguy Magne and Olga Sorkine-Hornung},
  booktitle = {SIGGRAPH ASIA, Technical Papers},
  year = {2023},
  url={https://doi.org/10.1145/3610548.3618174}
}
```
UVDoc_official/__pycache__/data_UVDoc.cpython-310.pyc
ADDED
Binary file (6.7 kB).

UVDoc_official/__pycache__/data_UVDoc.cpython-313.pyc
ADDED
Binary file (11.4 kB).

UVDoc_official/__pycache__/data_custom_augmentations.cpython-310.pyc
ADDED
Binary file (4.17 kB).

UVDoc_official/__pycache__/data_doc3D.cpython-312.pyc
ADDED
Binary file (5.48 kB).

UVDoc_official/__pycache__/data_utils.cpython-313.pyc
ADDED
Binary file (8.02 kB).
UVDoc_official/compute_uvdoc_grid3d_stats.py
ADDED
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
"""
Scan a UVDoc-style data_root (grid3d/*.mat) and write min/max per channel for the raw grid3d.
The output JSON is consumed by UVDocDataset via --uvdoc_grid3d_stats.

Uses one file per geom name (same as training), so cost scales with unique geometries, not images.
"""
import argparse
import json
import os
from os.path import join as pjoin

import h5py as h5
import numpy as np


def main():
    parser = argparse.ArgumentParser(description="Compute grid3d min/max stats for UVDoc normalization.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Dataset root containing grid3d/*.mat (HDF5 mat with dataset 'grid3d').",
    )
    parser.add_argument(
        "--out",
        type=str,
        required=True,
        help="Output JSON path (x_max, x_min, y_max, y_min, z_max, z_min).",
    )
    args = parser.parse_args()

    grid3d_dir = pjoin(args.data_path, "grid3d")
    if not os.path.isdir(grid3d_dir):
        raise FileNotFoundError(f"Missing grid3d directory: {grid3d_dir}")

    mats = sorted(f[:-4] for f in os.listdir(grid3d_dir) if f.endswith(".mat"))
    if not mats:
        raise RuntimeError(f"No .mat files under {grid3d_dir}")

    xmn = ymn = zmn = float("inf")
    xmx = ymx = zmx = float("-inf")

    for name in mats:
        path = pjoin(grid3d_dir, f"{name}.mat")
        with h5.File(path, "r") as file:
            grid3d = np.array(file["grid3d"][:].T)
        xmn = min(xmn, float(grid3d[:, :, 0].min()))
        xmx = max(xmx, float(grid3d[:, :, 0].max()))
        ymn = min(ymn, float(grid3d[:, :, 1].min()))
        ymx = max(ymx, float(grid3d[:, :, 1].max()))
        zmn = min(zmn, float(grid3d[:, :, 2].min()))
        zmx = max(zmx, float(grid3d[:, :, 2].max()))

    stats = {
        "x_max": xmx,
        "x_min": xmn,
        "y_max": ymx,
        "y_min": ymn,
        "z_max": zmx,
        "z_min": zmn,
        "num_grid3d_files": len(mats),
    }
    out_dir = os.path.dirname(os.path.abspath(args.out))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print(f"Wrote {args.out} from {len(mats)} grid3d files.")


if __name__ == "__main__":
    main()
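The script above keeps running per-channel minima and maxima instead of loading every grid at once. That update step can be illustrated without any HDF5 files; the synthetic arrays below stand in for the `grid3d/*.mat` contents and are not real data:

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-ins for grid3d arrays of shape (H, W, 3); real ones come from .mat files.
grids = [rng.uniform(-0.2, 0.2, size=(89, 61, 3)) for _ in range(5)]

# Running per-channel min/max, updated one grid at a time (as the script does).
mins = np.full(3, np.inf)
maxs = np.full(3, -np.inf)
for g in grids:
    pts = g.reshape(-1, 3)
    mins = np.minimum(mins, pts.min(axis=0))
    maxs = np.maximum(maxs, pts.max(axis=0))

# Equivalent to a single pass over the concatenation of all grids.
all_pts = np.concatenate([g.reshape(-1, 3) for g in grids])
assert np.allclose(mins, all_pts.min(axis=0))
assert np.allclose(maxs, all_pts.max(axis=0))
```

The incremental form is what keeps memory use flat regardless of how many `.mat` files the scan covers.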
UVDoc_official/data/readme.txt
ADDED
@@ -0,0 +1 @@
Add doc3D and UVDoc data here
UVDoc_official/data_UVDoc.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
from os.path import join as pjoin
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import h5py as h5
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from data_utils import BaseDataset, get_geometric_transform
|
| 15 |
+
from utils import GRID_SIZE, IMG_SIZE, bilinear_unwarping
|
| 16 |
+
|
| 17 |
+
# Default stats from the original UVDoc release (x_max, x_min, y_max, y_min, z_max, z_min).
|
| 18 |
+
DEFAULT_GRID3D_NORMALIZATION = (
|
| 19 |
+
0.11433014,
|
| 20 |
+
-0.12551452,
|
| 21 |
+
0.12401487,
|
| 22 |
+
-0.12401487,
|
| 23 |
+
0.1952378,
|
| 24 |
+
-0.1952378,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_grid3d_stats_json(path: str) -> Tuple[float, float, float, float, float, float]:
|
| 29 |
+
"""Load (x_max, x_min, y_max, y_min, z_max, z_min) from JSON written by compute_uvdoc_grid3d_stats.py."""
|
| 30 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 31 |
+
d = json.load(f)
|
| 32 |
+
keys = ("x_max", "x_min", "y_max", "y_min", "z_max", "z_min")
|
| 33 |
+
missing = [k for k in keys if k not in d]
|
| 34 |
+
if missing:
|
| 35 |
+
raise KeyError(f"grid3d stats JSON missing keys {missing}: {path}")
|
| 36 |
+
return tuple(float(d[k]) for k in keys) # type: ignore[return-value]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _split_samples_by_id(
|
| 40 |
+
ids: List[str],
|
| 41 |
+
split: str,
|
| 42 |
+
val_ratio: float,
|
| 43 |
+
split_seed: int,
|
| 44 |
+
) -> List[str]:
|
| 45 |
+
rng = random.Random(split_seed)
|
| 46 |
+
order = ids[:]
|
| 47 |
+
rng.shuffle(order)
|
| 48 |
+
n_val = max(1, int(round(len(order) * float(val_ratio))))
|
| 49 |
+
n_train = len(order) - n_val
|
| 50 |
+
if split == "train":
|
| 51 |
+
return order[:n_train]
|
| 52 |
+
return order[n_train:]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _split_samples_by_geom(
|
| 56 |
+
ids: List[str],
|
| 57 |
+
dataroot: str,
|
| 58 |
+
split: str,
|
| 59 |
+
val_ratio: float,
|
| 60 |
+
split_seed: int,
|
| 61 |
+
) -> Optional[List[str]]:
|
| 62 |
+
"""
|
| 63 |
+
Split so that no geom_name appears in both train and val.
|
| 64 |
+
Returns None if splitting by geom is impossible (e.g. a single unique geometry).
|
| 65 |
+
"""
|
| 66 |
+
geom_to_samples = {}
|
| 67 |
+
for sid in ids:
|
| 68 |
+
with open(pjoin(dataroot, "metadata_sample", f"{sid}.json"), "r", encoding="utf-8") as f:
|
| 69 |
+
g = json.load(f)["geom_name"]
|
| 70 |
+
geom_to_samples.setdefault(g, []).append(sid)
|
| 71 |
+
|
| 72 |
+
geoms = list(geom_to_samples.keys())
|
| 73 |
+
n_geoms = len(geoms)
|
| 74 |
+
if n_geoms <= 1:
|
| 75 |
+
warnings.warn(
|
| 76 |
+
"UVDocDataset: split_mode=geom but unique geom_name count is <= 1; "
|
| 77 |
+
"falling back to sample-level split.",
|
| 78 |
+
UserWarning,
|
| 79 |
+
stacklevel=3,
|
| 80 |
+
)
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
rng = random.Random(split_seed)
|
| 84 |
+
order = geoms[:]
|
| 85 |
+
rng.shuffle(order)
|
| 86 |
+
n_val_geoms = max(1, int(round(n_geoms * float(val_ratio))))
|
| 87 |
+
if n_val_geoms >= n_geoms:
|
| 88 |
+
n_val_geoms = n_geoms - 1
|
| 89 |
+
|
| 90 |
+
val_geom_set = set(order[-n_val_geoms:])
|
| 91 |
+
|
| 92 |
+
train_samples = []
|
| 93 |
+
val_samples = []
|
| 94 |
+
for g, sids in geom_to_samples.items():
|
| 95 |
+
if g in val_geom_set:
|
| 96 |
+
val_samples.extend(sids)
|
| 97 |
+
else:
|
| 98 |
+
train_samples.extend(sids)
|
| 99 |
+
|
| 100 |
+
train_samples.sort()
|
| 101 |
+
val_samples.sort()
|
| 102 |
+
return train_samples if split == "train" else val_samples
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class UVDocDataset(BaseDataset):
|
| 106 |
+
"""
|
| 107 |
+
Torch dataset class for the UVDoc dataset.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
data_path="./data/UVdoc",
|
| 113 |
+
appearance_augmentation=[],
|
| 114 |
+
geometric_augmentations=[],
|
| 115 |
+
grid_size=GRID_SIZE,
|
| 116 |
+
split=None,
|
| 117 |
+
val_ratio=0.05,
|
| 118 |
+
split_seed=42,
|
| 119 |
+
split_mode="sample",
|
| 120 |
+
grid3d_stats_path: Optional[str] = None,
|
| 121 |
+
deterministic_crop=None,
|
| 122 |
+
max_samples=None,
|
| 123 |
+
overfit=False,
|
| 124 |
+
) -> None:
|
| 125 |
+
super().__init__(
|
| 126 |
+
data_path=data_path,
|
| 127 |
+
appearance_augmentation=appearance_augmentation,
|
| 128 |
+
img_size=IMG_SIZE,
|
| 129 |
+
grid_size=grid_size,
|
| 130 |
+
)
|
| 131 |
+
self.original_grid_size = (89, 61) # size of the captured data
|
| 132 |
+
if grid3d_stats_path:
|
| 133 |
+
self.grid3d_normalization = load_grid3d_stats_json(grid3d_stats_path)
|
| 134 |
+
else:
|
| 135 |
+
self.grid3d_normalization = DEFAULT_GRID3D_NORMALIZATION
|
| 136 |
+
self.geometric_transform = get_geometric_transform(geometric_augmentations, gridsize=self.original_grid_size)
|
| 137 |
+
|
| 138 |
+
ids = sorted([x[:-4] for x in os.listdir(pjoin(self.dataroot, "img")) if x.endswith(".png")])
|
| 139 |
+
if max_samples is not None:
|
| 140 |
+
ids = ids[: int(max_samples)]
|
| 141 |
+
if overfit:
|
| 142 |
+
self.all_samples = ids
|
| 143 |
+
elif split in ("train", "val"):
|
| 144 |
+
if split_mode == "geom":
|
| 145 |
+
assigned = _split_samples_by_geom(ids, self.dataroot, split, val_ratio, split_seed)
|
| 146 |
+
if assigned is None:
|
| 147 |
+
assigned = _split_samples_by_id(ids, split, val_ratio, split_seed)
|
| 148 |
+
self.all_samples = assigned
|
| 149 |
+
            elif split_mode == "sample":
                self.all_samples = _split_samples_by_id(ids, split, val_ratio, split_seed)
            else:
                raise ValueError(f"split_mode must be 'sample' or 'geom', got {split_mode!r}")
        else:
            self.all_samples = ids

        if deterministic_crop is None:
            self.deterministic_crop = split == "val"
        else:
            self.deterministic_crop = bool(deterministic_crop)

    def __getitem__(self, index):
        # Get all paths
        sample_id = self.all_samples[index]
        with open(pjoin(self.dataroot, "metadata_sample", f"{sample_id}.json"), "r", encoding="utf-8") as f:
            sample_name = json.load(f)["geom_name"]
        img_path = pjoin(self.dataroot, "img", f"{sample_id}.png")
        grid2D_path = pjoin(self.dataroot, "grid2d", f"{sample_name}.mat")
        grid3D_path = pjoin(self.dataroot, "grid3d", f"{sample_name}.mat")

        # Load 2D grid, 3D grid and image. Normalize 3D grid
        with h5.File(grid2D_path, "r") as file:
            grid2D_ = np.array(file["grid2d"][:].T.transpose(2, 0, 1))  # scale in range of img resolution

        with h5.File(grid3D_path, "r") as file:
            grid3D = np.array(file["grid3d"][:].T)

        if self.normalize_3Dgrid:  # scale grid3D to [0,1], based on stats computed over the entire dataset
            xmx, xmn, ymx, ymn, zmx, zmn = self.grid3d_normalization
            eps = 1e-12
            for c, cmn, cmx in ((0, xmn, xmx), (1, ymn, ymx), (2, zmn, zmx)):
                denom = cmx - cmn
                if abs(denom) < eps:
                    grid3D[:, :, c] = 0.0
                else:
                    grid3D[:, :, c] = (grid3D[:, :, c] - cmn) / denom
        grid3D = np.array(grid3D, dtype=np.float32)
        grid3D = torch.from_numpy(grid3D.transpose(2, 0, 1))

        img_RGB_ = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)

        # Pixel-wise augmentation
        img_RGB_ = self.appearance_transform(image=img_RGB_)["image"]

        # Geometric augmentations
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            transformed = self.geometric_transform(
                image=img_RGB_,
                keypoints=grid2D_.transpose(1, 2, 0).reshape(-1, 2),
            )
        img_RGB_ = transformed["image"]

        grid2D_ = np.array(transformed["keypoints"]).reshape(*self.original_grid_size, 2).transpose(2, 0, 1)

        flipped = False
        for x in transformed["replay"]["transforms"]:
            if "SafeHorizontalFlip" in x["__class_fullname__"]:
                flipped = x["applied"]
        if flipped:
            grid3D[1] = 1 - grid3D[1]
            grid3D = torch.flip(grid3D, dims=(2,))

        # Tight crop
        grid2Dtmp = grid2D_
        img_RGB, grid2D = self.crop_tight(img_RGB_, grid2Dtmp, deterministic=self.deterministic_crop)

        # Subsample grids to the desired resolution
        row_sampling_factor = math.ceil(self.original_grid_size[0] / self.grid_size[0])
        col_sampling_factor = math.ceil(self.original_grid_size[1] / self.grid_size[1])
        grid3D = grid3D[:, ::row_sampling_factor, ::col_sampling_factor]
        grid2D = grid2D[:, ::row_sampling_factor, ::col_sampling_factor]
        grid2D = torch.from_numpy(grid2D).float()

        # Unwarp the image according to the grid
        img_RGB_unwarped = bilinear_unwarping(img_RGB.unsqueeze(0), grid2D.unsqueeze(0), self.img_size).squeeze()

        return (
            img_RGB.float() / 255.0,
            img_RGB_unwarped.float() / 255.0,
            grid2D,
            grid3D,
        )
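The per-axis min-max normalization above guards against a degenerate axis. A minimal sketch of that guard, not part of the repo (the grid values and the `stats` tuple here are made up; only the tuple layout mirrors `grid3d_normalization`):

```python
import numpy as np

# Illustration: an axis whose range is (near) zero, e.g. the z-axis of a
# perfectly flat sheet, is set to 0.0 instead of being divided by ~0.
grid3D = np.zeros((2, 2, 3))
grid3D[:, :, 0] = [[0.0, 1.0], [2.0, 4.0]]  # x varies across the sheet
grid3D[:, :, 2] = 0.5                       # z is constant: a flat sheet
stats = (4.0, 0.0, 1.0, 0.0, 0.5, 0.5)      # hypothetical (xmx, xmn, ymx, ymn, zmx, zmn)

xmx, xmn, ymx, ymn, zmx, zmn = stats
eps = 1e-12
for c, cmn, cmx in ((0, xmn, xmx), (1, ymn, ymx), (2, zmn, zmx)):
    denom = cmx - cmn
    if abs(denom) < eps:
        grid3D[:, :, c] = 0.0  # degenerate axis: avoid dividing by zero
    else:
        grid3D[:, :, c] = (grid3D[:, :, c] - cmn) / denom

print(grid3D[:, :, 0])  # [[0. 0.25], [0.5 1.]]; the constant z channel is all zeros
```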
UVDoc_official/data_custom_augmentations.py
ADDED
@@ -0,0 +1,148 @@
import albumentations as A
import cv2
import numpy as np

from utils import GRID_SIZE


class SafeHorizontalFlip(A.HorizontalFlip):
    """
    Horizontal flip that reorders the keypoints so that the top-left one remains in the top-left position.
    """

    def __init__(self, gridsize=GRID_SIZE, always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.gridsize = gridsize

    def apply_to_keypoints(self, keypoints, **params):
        keypoints = super().apply_to_keypoints(keypoints, **params)
        keypoints = np.array(keypoints).reshape(*self.gridsize, -1)[:, ::-1, :]
        keypoints = keypoints.reshape(np.prod(self.gridsize), -1)  # np.prod: np.product is removed in NumPy 2.0
        return keypoints

    def get_transform_init_args_names(self):
        return ("gridsize",)


class SafePerspective(A.Perspective):
    """
    Perspective augmentation that keeps all keypoints in the image visible.
    Mostly copied from the original Perspective augmentation from Albumentations.
    """

    def __init__(
        self,
        scale=(0.05, 0.1),
        keep_size=True,
        pad_mode=cv2.BORDER_CONSTANT,
        pad_val=0,
        mask_pad_val=0,
        fit_output=False,
        interpolation=cv2.INTER_LINEAR,
        always_apply=False,
        p=0.5,
    ):
        super().__init__(
            scale,
            keep_size,
            pad_mode,
            pad_val,
            mask_pad_val,
            fit_output,
            interpolation,
            always_apply,
            p,
        )

    @property
    def targets_as_params(self):
        return ["image", "keypoints"]

    def get_params_dependent_on_targets(self, params):
        h, w = params["image"].shape[:2]
        keypoints = np.array(params["keypoints"])[:, :2] / np.array([w, h])
        left = np.min(keypoints[:, 0])
        right = np.max(keypoints[:, 0])
        top = np.min(keypoints[:, 1])
        bottom = np.max(keypoints[:, 1])

        points = np.zeros([4, 2])
        # Top left point
        points[0, 0] = A.random_utils.uniform(0, max(left - 0.01, left / 2))
        points[0, 1] = A.random_utils.uniform(0, max(top - 0.01, top / 2))
        # Top right point
        points[1, 0] = A.random_utils.uniform(min(right + 0.01, (right + 1) / 2), 1)
        points[1, 1] = A.random_utils.uniform(0, max(top - 0.01, top / 2))
        # Bottom right point
        points[2, 0] = A.random_utils.uniform(min(right + 0.01, (right + 1) / 2), 1)
        points[2, 1] = A.random_utils.uniform(min(bottom + 0.01, (bottom + 1) / 2), 1)
        # Bottom left point
        points[3, 0] = A.random_utils.uniform(0, max(left - 0.01, left / 2))
        points[3, 1] = A.random_utils.uniform(min(bottom + 0.01, (bottom + 1) / 2), 1)

        points[:, 0] *= w
        points[:, 1] *= h

        # Obtain a consistent order of the points and unpack them individually.
        # Warning: don't just do (tl, tr, br, bl) = _order_points(...) here,
        # because the reordered points array is used further below.
        points = self._order_points(points)
        tl, tr, br, bl = points

        # Compute the width of the new image: the maximum distance between the
        # bottom-right and bottom-left x-coordinates or the top-right and top-left x-coordinates
        min_width = None
        max_width = None
        while min_width is None or min_width < 2:
            width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
            width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
            max_width = int(max(width_top, width_bottom))
            min_width = int(min(width_top, width_bottom))
            if min_width < 2:
                step_size = (2 - min_width) / 2
                tl[0] -= step_size
                tr[0] += step_size
                bl[0] -= step_size
                br[0] += step_size

        # Compute the height of the new image: the maximum distance between the top-right
        # and bottom-right y-coordinates or the top-left and bottom-left y-coordinates
        min_height = None
        max_height = None
        while min_height is None or min_height < 2:
            height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
            height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
            max_height = int(max(height_right, height_left))
            min_height = int(min(height_right, height_left))
            if min_height < 2:
                step_size = (2 - min_height) / 2
                tl[1] -= step_size
                tr[1] -= step_size
                bl[1] += step_size
                br[1] += step_size

        # Now that we have the dimensions of the new image, construct the set of
        # destination points to obtain a "bird's-eye view" (i.e. top-down view) of
        # the image, again specifying points in top-left, top-right, bottom-right,
        # bottom-left order. Do not use width-1 or height-1 here, as for e.g.
        # width=3, height=2 the bottom-right coordinate is at (3.0, 2.0), not (2.0, 1.0).
        dst = np.array(
            [[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]],
            dtype=np.float32,
        )

        # Compute the perspective transform matrix and then apply it
        m = cv2.getPerspectiveTransform(points, dst)

        if self.fit_output:
            m, max_width, max_height = self._expand_transform(m, (h, w))

        return {
            "matrix": m,
            "max_height": max_height,
            "max_width": max_width,
            "interpolation": self.interpolation,
        }
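A minimal numpy sketch, not part of the repo, of why `SafeHorizontalFlip.apply_to_keypoints` reverses the keypoint columns: after mirroring x-coordinates, the keypoint that was top-left lands in the top-right grid slot, so reversing the column order restores the convention that grid slot (0, 0) holds the top-left corner. The grid size, image width, and the `w - x` mirror used here are illustrative assumptions:

```python
import numpy as np

gridsize = (3, 4)  # (rows, cols), stand-in for GRID_SIZE
w = 100            # hypothetical image width

# Regular grid of (x, y) keypoints, row-major
xs, ys = np.meshgrid(np.linspace(10, 90, gridsize[1]), np.linspace(10, 90, gridsize[0]))
kps = np.stack([xs, ys], axis=-1).reshape(-1, 2)

# A horizontal flip mirrors x only
flipped = kps.copy()
flipped[:, 0] = w - flipped[:, 0]

# Reverse the column order, as apply_to_keypoints does
reordered = flipped.reshape(*gridsize, 2)[:, ::-1, :].reshape(-1, 2)

print(reordered[0])  # grid slot (0, 0) again holds the left-most, top-most point
```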
UVDoc_official/data_doc3D.py
ADDED
@@ -0,0 +1,95 @@
from os.path import join as pjoin

import cv2
import h5py as h5
import numpy as np
import torch

from data_utils import BaseDataset
from utils import GRID_SIZE, IMG_SIZE, bilinear_unwarping


class doc3DDataset(BaseDataset):
    """
    Torch dataset class for the Doc3D dataset.
    """

    def __init__(
        self,
        data_path="./data/doc3D",
        split="train",
        appearance_augmentation=[],
        grid_size=GRID_SIZE,
    ):
        super().__init__(
            data_path=data_path,
            appearance_augmentation=appearance_augmentation,
            img_size=IMG_SIZE,
            grid_size=grid_size,
        )
        self.grid3d_normalization = (1.2539363, -1.2442188, 1.2396319, -1.2289206, 0.6436657, -0.67492497)

        if split == "train":
            path = pjoin(self.dataroot, "traindoc.txt")
        elif split == "val":
            path = pjoin(self.dataroot, "valdoc3D.txt")

        with open(path, "r") as files:
            file_list = tuple(files)
        self.all_samples = np.array([id_.rstrip() for id_ in file_list], dtype=np.string_)

    def __getitem__(self, index):
        # Get all paths
        im_name = self.all_samples[index].decode("UTF-8")
        img_path = pjoin(self.dataroot, "img", im_name + ".png")
        grid2D_path = pjoin(self.dataroot, "grid2D", im_name + ".mat")
        grid3D_path = pjoin(self.dataroot, "grid3D", im_name + ".mat")
        bm_path = pjoin(self.dataroot, "bm", im_name + ".mat")

        # Load 2D grid, 3D grid and image. Normalize 3D grid
        with h5.File(grid2D_path, "r") as file:
            grid2D_ = np.array(file["grid2D"][:].T.transpose(2, 0, 1))  # scale in range of img resolution

        with h5.File(grid3D_path, "r") as file:
            grid3D = np.array(file["grid3D"][:].T)

        if self.normalize_3Dgrid:  # scale grid3D to [0,1], based on stats computed over the entire dataset
            xmx, xmn, ymx, ymn, zmx, zmn = self.grid3d_normalization
            grid3D[:, :, 0] = (grid3D[:, :, 0] - zmn) / (zmx - zmn)
            grid3D[:, :, 1] = (grid3D[:, :, 1] - ymn) / (ymx - ymn)
            grid3D[:, :, 2] = (grid3D[:, :, 2] - xmn) / (xmx - xmn)
        grid3D = np.array(grid3D, dtype=np.float32)
        grid3D[:, :, 1] = grid3D[:, :, 1][:, ::-1]
        grid3D[:, :, 1] = 1 - grid3D[:, :, 1]
        grid3D = torch.from_numpy(grid3D.transpose(2, 0, 1))

        img_RGB_ = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

        # Pixel-wise augmentation
        img_RGB_ = self.appearance_transform(image=img_RGB_)["image"]

        # Create unwarped image according to the backward mapping (first load the backward mapping)
        with h5.File(bm_path, "r") as file:
            bm = np.array(file["bm"][:].T.transpose(2, 0, 1))
        bm = ((bm / 448) - 0.5) * 2.0
        bm = torch.from_numpy(bm).float()

        img_RGB_unwarped = bilinear_unwarping(
            torch.from_numpy(img_RGB_.transpose(2, 0, 1)).float().unsqueeze(0),
            bm.unsqueeze(0),
            self.img_size,
        ).squeeze()

        # Tight crop
        grid2Dtmp = grid2D_
        img_RGB, grid2D = self.crop_tight(img_RGB_, grid2Dtmp)

        # Convert 2D grid to torch tensor
        grid2D = torch.from_numpy(grid2D).float()

        return (
            img_RGB.float() / 255.0,
            img_RGB_unwarped.float() / 255.0,
            grid2D,
            grid3D,
        )
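A small illustration, not part of the repo, of the backward-map normalization `bm = ((bm / 448) - 0.5) * 2.0` above: Doc3D backward maps are stored in pixel units of the 448x448 source image, and this affine mapping converts them to the [-1, 1] sampling convention used by grid-based warping:

```python
import numpy as np

# Pixel coordinates 0, 224, 448 map to the left edge, center, and right edge
bm_px = np.array([0.0, 224.0, 448.0])
bm_norm = ((bm_px / 448) - 0.5) * 2.0
print(bm_norm)  # [-1.  0.  1.]
```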
UVDoc_official/data_mixDataset.py
ADDED
@@ -0,0 +1,23 @@
import torch


class mixDataset(torch.utils.data.Dataset):
    """
    Class to use both the UVDoc and Doc3D datasets at the same time.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, ii):
        if len(self.datasets[0]) < len(self.datasets[1]):
            len_shortest = len(self.datasets[0])
            i_shortest = ii % len_shortest
            return self.datasets[0][i_shortest], self.datasets[1][ii]
        else:
            len_shortest = len(self.datasets[1])
            jj = ii % len_shortest
            return self.datasets[0][ii], self.datasets[1][jj]

    def __len__(self):
        return max(len(d) for d in self.datasets)
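A minimal sketch, not part of the repo, of the pairing behaviour above: the longer dataset is indexed directly while the shorter one is cycled with a modulo, so one epoch covers every sample of the longer set. `TinyMix` and the two lists are hypothetical stand-ins that reproduce only the indexing logic:

```python
class TinyMix:
    """Plain-Python stand-in for mixDataset's indexing (no torch dependency)."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, ii):
        if len(self.datasets[0]) < len(self.datasets[1]):
            # Cycle the shorter first dataset, index the longer one directly
            return self.datasets[0][ii % len(self.datasets[0])], self.datasets[1][ii]
        return self.datasets[0][ii], self.datasets[1][ii % len(self.datasets[1])]

    def __len__(self):
        return max(len(d) for d in self.datasets)


short = ["a", "b"]       # stand-in for the shorter dataset
long_ = [0, 1, 2, 3, 4]  # stand-in for the longer dataset
mix = TinyMix(short, long_)

pairs = [mix[i] for i in range(len(mix))]
print(pairs)  # [('a', 0), ('b', 1), ('a', 2), ('b', 3), ('a', 4)]
```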
UVDoc_official/data_utils.py
ADDED
@@ -0,0 +1,175 @@
import random

import albumentations as A
import cv2
import numpy as np
import torch

from data_custom_augmentations import SafeHorizontalFlip, SafePerspective
from utils import GRID_SIZE, IMG_SIZE


def get_appearance_transform(transform_types):
    """
    Returns an albumentations Compose augmentation.

    transform_types is a list containing the types of pixel-wise data augmentation to use.
    Possible augmentations are 'shadow', 'blur', 'visual', 'noise', 'color'.
    """
    transforms = []
    if "shadow" in transform_types:
        transforms.append(A.RandomShadow(p=0.1))
    if "blur" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.Defocus(p=0.05),
                    A.Downscale(p=0.15, interpolation=cv2.INTER_LINEAR),
                    A.GaussianBlur(p=0.65),
                    A.MedianBlur(p=0.15),
                ],
                p=0.75,
            )
        )
    if "visual" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.ToSepia(p=0.15),
                    A.ToGray(p=0.20),
                    A.Equalize(p=0.15),
                    A.Sharpen(p=0.20),
                ],
                p=0.5,
            )
        )
    if "noise" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.GaussNoise(var_limit=(10.0, 20.0), p=0.70),
                    A.ISONoise(intensity=(0.1, 0.25), p=0.30),
                ],
                p=0.6,
            )
        )
    if "color" in transform_types:
        transforms.append(
            A.OneOf(
                transforms=[
                    A.ColorJitter(p=0.05),
                    A.HueSaturationValue(p=0.10),
                    A.RandomBrightnessContrast(brightness_limit=[-0.05, 0.25], p=0.85),
                ],
                p=0.95,
            )
        )

    return A.Compose(transforms=transforms)


def get_geometric_transform(transform_types, gridsize):
    """
    Returns an albumentations ReplayCompose augmentation.

    transform_types is a list containing the types of geometric data augmentation to use.
    Possible augmentations are 'rotate', 'flip' and 'perspective'.
    """
    transforms = []
    if "rotate" in transform_types:
        transforms.append(
            A.SafeRotate(
                limit=[-30, 30],
                interpolation=cv2.INTER_LINEAR,
                border_mode=cv2.BORDER_REPLICATE,
                p=0.5,
            )
        )
    if "flip" in transform_types:
        transforms.append(SafeHorizontalFlip(gridsize=gridsize, p=0.25))

    if "perspective" in transform_types:
        transforms.append(SafePerspective(p=0.5))

    return A.ReplayCompose(
        transforms=transforms,
        keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
    )


def crop_image_tight(img, grid2D, deterministic=False):
    """
    Crops the image tightly around the keypoints in grid2D.
    This function creates a tight crop around the document in the image.
    """
    size = img.shape

    minx = np.floor(np.amin(grid2D[0, :, :])).astype(int)
    maxx = np.ceil(np.amax(grid2D[0, :, :])).astype(int)
    miny = np.floor(np.amin(grid2D[1, :, :])).astype(int)
    maxy = np.ceil(np.amax(grid2D[1, :, :])).astype(int)
    s = 20
    s = min(min(s, minx), miny)  # s shouldn't be larger than the natural padding actually available
    s = min(min(s, size[1] - 1 - maxx), size[0] - 1 - maxy)

    # Crop the image slightly larger than necessary
    img = img[miny - s : maxy + s, minx - s : maxx + s, :]
    hi = max(s - 5, 1)
    if deterministic:
        cx1 = cy1 = max(hi // 2, 0)
        cx2 = cy2 = max(hi // 2, 0) + 1
    else:
        cx1 = random.randint(0, hi)
        cx2 = random.randint(0, hi) + 1
        cy1 = random.randint(0, hi)
        cy2 = random.randint(0, hi) + 1

    img = img[cy1:-cy2, cx1:-cx2, :]
    top = miny - s + cy1
    bot = size[0] - maxy - s + cy2
    left = minx - s + cx1
    right = size[1] - maxx - s + cx2
    return img, top, bot, left, right


class BaseDataset(torch.utils.data.Dataset):
    """
    Base torch dataset class for all unwarping datasets.
    """

    def __init__(
        self,
        data_path,
        appearance_augmentation=[],
        img_size=IMG_SIZE,
        grid_size=GRID_SIZE,
    ) -> None:
        super().__init__()

        self.dataroot = data_path
        self.img_size = img_size
        self.grid_size = grid_size
        self.normalize_3Dgrid = True

        self.appearance_transform = get_appearance_transform(appearance_augmentation)

        self.all_samples = []

    def __len__(self):
        return len(self.all_samples)

    def crop_tight(self, img_RGB, grid2D, deterministic=False):
        # The incoming grid2D array is expressed in pixel coordinates (resolution of img_RGB before crop/resize)
        size = img_RGB.shape
        img, top, bot, left, right = crop_image_tight(img_RGB, grid2D, deterministic=deterministic)
        img = cv2.resize(img, self.img_size)
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()

        grid2D[0, :, :] = (grid2D[0, :, :] - left) / (size[1] - left - right)
        grid2D[1, :, :] = (grid2D[1, :, :] - top) / (size[0] - top - bot)
        grid2D = (grid2D * 2.0) - 1.0

        return img, grid2D
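A numpy-only sketch, not part of the repo, of the coordinate normalization performed at the end of `crop_tight`: pixel coordinates are first rescaled to [0, 1] relative to the crop extent, then mapped to [-1, 1], the convention expected by grid-sampling-based unwarping. The image size and crop margins below are made-up numbers:

```python
import numpy as np

size = (200, 300)     # (height, width) of the image before cropping
left, right = 40, 60  # hypothetical pixels removed on the left/right
top, bot = 20, 30     # hypothetical pixels removed on the top/bottom

# Two keypoints per channel: channel 0 holds x, channel 1 holds y
grid2D = np.array([[[40.0, 240.0]],   # x: crop left edge, crop right edge
                   [[20.0, 150.0]]])  # y: crop top edge, an interior point

grid2D[0] = (grid2D[0] - left) / (size[1] - left - right)
grid2D[1] = (grid2D[1] - top) / (size[0] - top - bot)
grid2D = grid2D * 2.0 - 1.0

print(grid2D[0])  # x: crop left edge -> -1, crop right edge -> 1
```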
UVDoc_official/demo.py
ADDED
@@ -0,0 +1,55 @@
import argparse
import os

import cv2
import numpy as np
import torch

from utils import IMG_SIZE, bilinear_unwarping, load_model


def unwarp_img(ckpt_path, img_path, img_size):
    """
    Unwarp a document image using the model from ckpt_path.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = load_model(ckpt_path)
    model.to(device)
    model.eval()

    # Load image
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
    inp = torch.from_numpy(cv2.resize(img, img_size).transpose(2, 0, 1)).unsqueeze(0)

    # Make prediction
    inp = inp.to(device)
    point_positions2D, _ = model(inp)

    # Unwarp
    size = img.shape[:2][::-1]
    unwarped = bilinear_unwarping(
        warped_img=torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device),
        point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
        img_size=tuple(size),
    )
    unwarped = (unwarped[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)

    # Save result
    unwarped_BGR = cv2.cvtColor(unwarped, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.splitext(img_path)[0] + "_unwarp.png", unwarped_BGR)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
    )
    parser.add_argument("--img-path", type=str, help="Path to the document image to unwarp.")

    args = parser.parse_args()

    unwarp_img(args.ckpt_path, args.img_path, IMG_SIZE)
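A tiny illustration, not part of the repo, of the shape bookkeeping in `demo.py`: OpenCV images are (height, width, channels), while the resize and unwarping sizes are (width, height), so `img.shape[:2][::-1]` performs that swap; the model runs at `IMG_SIZE` but the unwarp is rendered at the input's full resolution. The dimensions below are made up:

```python
import numpy as np

img = np.zeros((712, 488, 3), dtype=np.float32)  # H x W x C image array
size = img.shape[:2][::-1]                       # swap to (width, height)
print(size)  # (488, 712)
```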
UVDoc_official/docUnet_eval.py
ADDED
@@ -0,0 +1,106 @@
import argparse
import json
import multiprocessing as mp
import os

from utils import get_version


def visual_metrics_process(queue, docunet_path, preds_path, verbose):
    """
    Subprocess function that computes visual metrics (MS-SSIM, LD, and AD) based on a Matlab script.
    """
    import matlab.engine

    eng = matlab.engine.start_matlab()
    eng.cd(r"./eval/eval_code/", nargout=0)

    mean_ms, mean_ld, mean_ad = eng.evalScript(os.path.join(docunet_path, "scan"), preds_path, verbose, nargout=3)
    queue.put(dict(ms=mean_ms, ld=mean_ld, ad=mean_ad))


def ocr_process(queue, docunet_path, preds_path, crop_type):
    """
    Subprocess function that computes OCR metrics (CER and ED).
    """
    from eval.ocr_eval.ocr_eval import OCR_eval_docunet

    CERmean, EDmean, OCR_dict_results = OCR_eval_docunet(
        os.path.join(docunet_path, "scan"), preds_path, os.path.join(docunet_path, crop_type)
    )
    with open(os.path.join(preds_path, "ocr_res.json"), "w") as f:
        json.dump(OCR_dict_results, f)
    queue.put(dict(cer=CERmean, ed=EDmean))


def compute_metrics(docunet_path, preds_path, crop_type, verbose=False):
    """
    Compute and save all metrics.
    """
    if not preds_path.endswith("/"):
        preds_path += "/"
    q = mp.Queue()

    # Create process to compute MS-SSIM, LD, AD
    p1 = mp.Process(target=visual_metrics_process, args=(q, docunet_path, preds_path, verbose))
    p1.start()

    # Create process to compute OCR metrics
    p2 = mp.Process(target=ocr_process, args=(q, docunet_path, preds_path, crop_type))
    p2.start()

    p1.join()
    p2.join()

    # Get results
    res = {}
    for _ in range(q.qsize()):
        ret = q.get()
        for k, v in ret.items():
            res[k] = v

    # Print and save results
    print("--- Results ---")
    print(f" Mean MS-SSIM : {res['ms']}")
    print(f" Mean LD : {res['ld']}")
    print(f" Mean AD : {res['ad']}")
    print(f" Mean CER : {res['cer']}")
    print(f" Mean ED : {res['ed']}")

    with open(os.path.join(preds_path, "res.txt"), "w") as f:
        f.write(f"Mean MS-SSIM : {res['ms']}\n")
        f.write(f"Mean LD : {res['ld']}\n")
        f.write(f"Mean AD : {res['ad']}\n")
        f.write(f"Mean CER : {res['cer']}\n")
        f.write(f"Mean ED : {res['ed']}\n")

        model_info_path = os.path.join(preds_path, "model_info.txt")
        if os.path.isfile(model_info_path):
            with open(model_info_path) as modinf_f:
                for x in modinf_f.readlines():
                    f.write(x)

        f.write("\n--- Module Version ---\n")
        for module, version in get_version().items():
            f.write(f"{module:25s}: {version}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--docunet-path", type=str, default="./data/DocUNet/", help="Path to the DocUNet scans. Needs to be absolute."
    )
    parser.add_argument("--pred-path", type=str, help="Path to the DocUNet predictions. Needs to be absolute.")
    parser.add_argument(
        "--crop-type",
        type=str,
        default="crop",
        help="The type of cropping to use as input to the model: 'crop' or 'original'",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()

    compute_metrics(
        os.path.abspath(args.docunet_path), os.path.abspath(args.pred_path), args.crop_type, verbose=args.verbose
    )
UVDoc_official/docUnet_pred.py
ADDED
@@ -0,0 +1,179 @@
import argparse
import os
import platform
import re
import subprocess
import time

import cv2
import numpy as np
import torch
from tqdm import tqdm

from utils import IMG_SIZE, bilinear_unwarping, load_model


def get_processor_name():
    """
    Returns information about the processor used.
    Taken from https://stackoverflow.com/a/13078519.
    """
    if platform.system() == "Windows":
        return platform.processor()
    elif platform.system() == "Darwin":
        os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
        command = "sysctl -n machdep.cpu.brand_string"
        # check_output needs an argv list when shell=False, hence the split()
        return subprocess.check_output(command.split()).strip()
    elif platform.system() == "Linux":
        command = "cat /proc/cpuinfo"
        all_info = subprocess.check_output(command, shell=True).decode().strip()
        for line in all_info.split("\n"):
            if "model name" in line:
                return re.sub(".*model name.*:", "", line, count=1)
    return ""


def count_parameters(model):
    """
    Returns the number of trainable parameters of a model.
    Taken from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class docUnetLoader(torch.utils.data.Dataset):
    """
    Torch dataset class for the DocUNet benchmark dataset.
    """

    def __init__(
        self,
        data_path,
        crop="original",
        img_size=(488, 712),
    ):
        self.dataroot = data_path
        self.crop = crop
        self.im_list = os.listdir(os.path.join(self.dataroot, self.crop))
        self.img_size = img_size

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, index):
        im_name = self.im_list[index]
        img_path = os.path.join(self.dataroot, self.crop, im_name)
        img_RGB = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_RGB = torch.from_numpy(cv2.resize(img_RGB, tuple(self.img_size)).transpose(2, 0, 1))
        return img_RGB, im_name


def infer_docUnet(model, dataloader, device, save_path):
    """
    Unwarp all images in the DocUNet benchmark and save them.
    Also measures the time it takes to perform this operation.
    """
    model.eval()
    inference_times = []
    inferenceGPU_times = []
    for img_RGB, im_names in tqdm(dataloader):
        # Inference
        start_toGPU = time.time()
        img_RGB = img_RGB.to(device)
        start_inf = time.time()
        point_positions2D, _ = model(img_RGB)
        end_inf = time.time()

        # The warped image needs to be re-opened to get full resolution (it is downsampled in the data loader)
        warped = cv2.imread(os.path.join(dataloader.dataset.dataroot, dataloader.dataset.crop, im_names[0]))
        warped = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB)
        warped = torch.from_numpy(warped.transpose(2, 0, 1) / 255.0).float()

        # To unwarp using the GT aspect ratio, uncomment the following lines and replace
        # `size = warped.shape[1:][::-1]` with `size = gt.shape[:2]`
        # gt = cv2.imread(
        #     os.path.join(
        #         dataloader.dataset.dataroot,
        #         "scan",
        #         im_names[0].split("_")[0] + ".png",
        #     )
        # )
        size = warped.shape[1:][::-1]

        # Unwarping
        start_unwarp = time.time()
        unwarped = bilinear_unwarping(
            warped_img=torch.unsqueeze(warped, dim=0).to(device),
            point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
            img_size=tuple(size),
        )
        end_unwarp = time.time()
        unwarped = (unwarped[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        unwarped_BGR = cv2.cvtColor(unwarped, cv2.COLOR_RGB2BGR)
        end_toGPU = time.time()

        cv2.imwrite(
            os.path.join(save_path, im_names[0].split(" ")[0].split(".")[0] + ".png"),
            unwarped_BGR,
        )

        inference_times.append(end_inf - start_inf + end_unwarp - start_unwarp)
        inferenceGPU_times.append(end_inf - start_toGPU + end_toGPU - start_unwarp)

    # Compute the average inference time and the number of parameters of the model
    avg_inference_time = np.mean(inference_times)
    avg_inferenceGPU_time = np.mean(inferenceGPU_times)
    n_params = count_parameters(model)
    return avg_inference_time, avg_inferenceGPU_time, n_params


def create_results(ckpt_path, docUnet_path, crop, img_size):
    """
    Create results for the DocUNet benchmark.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model, create dataset and save directory
    model = load_model(ckpt_path)
    model.to(device)

    dataset = docUnetLoader(docUnet_path, crop, img_size=img_size)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)

    save_path = os.path.join(os.path.dirname(ckpt_path), "docunet", crop)
    os.makedirs(save_path, exist_ok=False)
    print(f" Results will be saved at {save_path}", flush=True)

    # Infer results from the model and save metadata
    inference_time, inferenceGPU_time, n_params = infer_docUnet(model, dataloader, device, save_path)
    with open(os.path.join(save_path, "model_info.txt"), "w") as f:
        f.write("\n---Model and Hardware Information---\n")
        f.write(f"Inference Time : {inference_time:.5f}s\n")
        f.write(f"                 FPS : {1/inference_time:.1f}\n")
        f.write(f"Inference Time (Include Loading To/From GPU) : {inferenceGPU_time:.5f}s\n")
        f.write(f"                 FPS : {1/inferenceGPU_time:.1f}\n")
        f.write("Using :\n")
        f.write(f"    CPU : {get_processor_name()}\n")
        f.write(f"    GPU : {torch.cuda.get_device_name(0)}\n")
        f.write(f"Number of Parameters : {n_params:,}\n")
    return save_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
    )
    parser.add_argument("--docunet-path", type=str, default="./data/DocUNet", help="Path to the docunet benchmark.")
    parser.add_argument(
        "--crop-type",
        type=str,
        default="crop",
        help="The type of cropping to use as input of the model : 'crop' or 'original'.",
    )

    args = parser.parse_args()

    create_results(args.ckpt_path, os.path.abspath(args.docunet_path), args.crop_type, IMG_SIZE)
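The saving rule in infer_docUnet (take the name up to the first space, drop the extension, append ".png") can be checked in isolation. `result_name` below is a hypothetical helper mirroring that one expression, not a function from the repo:

```python
def result_name(im_name: str) -> str:
    # Mirrors the cv2.imwrite filename expression in infer_docUnet:
    # strip anything after the first space, drop the extension, save as PNG.
    return im_name.split(" ")[0].split(".")[0] + ".png"

print(result_name("1_1 copy.png"))  # -> 1_1.png
print(result_name("10_2.jpg"))      # -> 10_2.png
```

This is why DocUNet benchmark inputs such as "1_1 copy.png" map onto the scan names ("1_1.png") expected by the evaluation scripts.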
UVDoc_official/model.py
ADDED
|
@@ -0,0 +1,374 @@
import torch
import torch.nn as nn


def conv3x3(in_channels, out_channels, kernel_size, stride=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def dilated_conv_bn_act(in_channels, out_channels, act_fn, BatchNorm, dilation):
    model = nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            bias=False,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        ),
        BatchNorm(out_channels),
        act_fn,
    )
    return model


def dilated_conv(in_channels, out_channels, kernel_size, dilation, stride=1):
    model = nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size // 2),
            dilation=dilation,
        )
    )
    return model


class ResidualBlockWithDilation(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        BatchNorm,
        kernel_size,
        stride=1,
        downsample=None,
        is_activation=True,
        is_top=False,
    ):
        super(ResidualBlockWithDilation, self).__init__()
        self.stride = stride
        self.downsample = downsample
        self.is_activation = is_activation
        self.is_top = is_top
        if self.stride != 1 or self.is_top:
            self.conv1 = conv3x3(in_channels, out_channels, kernel_size, self.stride)
            self.conv2 = conv3x3(out_channels, out_channels, kernel_size)
        else:
            self.conv1 = dilated_conv(in_channels, out_channels, kernel_size, dilation=3)
            self.conv2 = dilated_conv(out_channels, out_channels, kernel_size, dilation=3)

        self.bn1 = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = BatchNorm(out_channels)

    def forward(self, x):
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)

        out1 = self.relu(self.bn1(self.conv1(x)))
        out2 = self.bn2(self.conv2(out1))

        out2 += residual
        out = self.relu(out2)
        return out


class ResnetStraight(nn.Module):
    def __init__(
        self,
        num_filter,
        map_num,
        BatchNorm,
        block_nums=[3, 4, 6, 3],
        block=ResidualBlockWithDilation,
        kernel_size=5,
        stride=[1, 1, 2, 2],
    ):
        super(ResnetStraight, self).__init__()
        self.in_channels = num_filter * map_num[0]
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)
        self.block_nums = block_nums
        self.kernel_size = kernel_size

        self.layer1 = self.blocklayer(
            block,
            num_filter * map_num[0],
            self.block_nums[0],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[0],
        )
        self.layer2 = self.blocklayer(
            block,
            num_filter * map_num[1],
            self.block_nums[1],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[1],
        )
        self.layer3 = self.blocklayer(
            block,
            num_filter * map_num[2],
            self.block_nums[2],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[2],
        )

    def blocklayer(self, block, out_channels, block_nums, BatchNorm, kernel_size, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(
                    self.in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
                BatchNorm(out_channels),
            )

        layers = []
        layers.append(
            block(
                self.in_channels,
                out_channels,
                BatchNorm,
                kernel_size,
                stride,
                downsample,
                is_top=True,
            )
        )
        self.in_channels = out_channels
        for _ in range(1, block_nums):
            layers.append(
                block(
                    out_channels,
                    out_channels,
                    BatchNorm,
                    kernel_size,
                    is_activation=True,
                    is_top=False,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        return out3


class UVDocnet(nn.Module):
    def __init__(self, num_filter, kernel_size=5):
        super(UVDocnet, self).__init__()
        self.num_filter = num_filter
        self.in_channels = 3
        self.kernel_size = kernel_size
        self.stride = [1, 2, 2, 2]

        BatchNorm = nn.BatchNorm2d
        act_fn = nn.ReLU(inplace=True)
        map_num = [1, 2, 4, 8, 16]

        self.resnet_head = nn.Sequential(
            nn.Conv2d(
                self.in_channels,
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
            nn.Conv2d(
                self.num_filter * map_num[0],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
        )

        self.resnet_down = ResnetStraight(
            self.num_filter,
            map_num,
            BatchNorm,
            block_nums=[3, 4, 6, 3],
            block=ResidualBlockWithDilation,
            kernel_size=self.kernel_size,
            stride=self.stride,
        )

        map_num_i = 2
        self.bridge_1 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=1,
            )
        )

        self.bridge_2 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=2,
            )
        )

        self.bridge_3 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=5,
            )
        )

        self.bridge_4 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [8, 3, 2]
            ]
        )

        self.bridge_5 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [12, 7, 4]
            ]
        )

        self.bridge_6 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [18, 12, 6]
            ]
        )

        self.bridge_concat = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[map_num_i] * 6,
                self.num_filter * map_num[2],
                bias=False,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            BatchNorm(self.num_filter * map_num[2]),
            act_fn,
        )

        self.out_point_positions2D = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[2],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(
                self.num_filter * map_num[0],
                2,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
        )

        self.out_point_positions3D = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[2],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(
                self.num_filter * map_num[0],
                3,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.2)
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                nn.init.xavier_normal_(m.weight, gain=0.2)

    def forward(self, x):
        resnet_head = self.resnet_head(x)
        resnet_down = self.resnet_down(resnet_head)
        bridge_1 = self.bridge_1(resnet_down)
        bridge_2 = self.bridge_2(resnet_down)
        bridge_3 = self.bridge_3(resnet_down)
        bridge_4 = self.bridge_4(resnet_down)
        bridge_5 = self.bridge_5(resnet_down)
        bridge_6 = self.bridge_6(resnet_down)
        bridge_concat = torch.cat([bridge_1, bridge_2, bridge_3, bridge_4, bridge_5, bridge_6], dim=1)
        bridge = self.bridge_concat(bridge_concat)

        out_point_positions2D = self.out_point_positions2D(bridge)
        out_point_positions3D = self.out_point_positions3D(bridge)

        return out_point_positions2D, out_point_positions3D
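The dilated convolutions above all use the "same" padding rule padding = dilation * (kernel_size // 2), which keeps the spatial size unchanged at stride 1 regardless of dilation. A quick sketch with the standard Conv2d output-size formula confirms this (conv_out_size is a hypothetical helper, not part of the repo):

```python
def conv_out_size(n, k, stride=1, padding=0, dilation=1):
    # Standard Conv2d output-size formula, per spatial dimension.
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1

# dilated_conv (k=5) and dilated_conv_bn_act (k=3) both pick
# padding = dilation * (k // 2), so the feature map size is preserved:
for k in (3, 5):
    for d in (1, 2, 3, 5, 8, 12, 18):
        assert conv_out_size(64, k, padding=d * (k // 2), dilation=d) == 64

# The resnet_head convs (k=5, stride=2, padding=2) halve the resolution:
print(conv_out_size(64, 5, stride=2, padding=2))  # -> 32
```

This size-preserving property is what lets the six bridge branches, each with different dilation rates, be concatenated channel-wise in UVDocnet.forward without any resampling.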
UVDoc_official/requirements_demo.txt
ADDED
|
@@ -0,0 +1,3 @@
numpy==1.23.4
opencv_python_headless==4.7.0.68
torch==1.13.0
UVDoc_official/requirements_eval.txt
ADDED
|
@@ -0,0 +1,10 @@
hdf5storage==0.1.18
jiwer==3.0.1
numpy==1.23.4
opencv_python_headless==4.7.0.68
Pillow==9.4.0
pytesseract==0.3.10
python_Levenshtein
scikit-image
torch==1.13.0
tqdm==4.64.1
UVDoc_official/requirements_train.txt
ADDED
|
@@ -0,0 +1,5 @@
albumentations==1.3.0
h5py==3.7.0
numpy==1.23.4
opencv_python_headless==4.7.0.68
torch==1.13.0
UVDoc_official/run_official_overfit_train_infer.sh
ADDED
|
@@ -0,0 +1,101 @@
#!/usr/bin/env bash
# Official-style UVDoc overfit: training uses the same UVDocDataset settings as
# verify_ckpt_val_pipeline.py (deterministic crop, no appearance/geo aug).
# Inference MUST use verify_ckpt_val_pipeline.py, not demo.py (full-image resize).
#
# Default OVERFIT_N=8 with BATCH_SIZE=8 so each batch is full (BatchNorm behaves
# closer to normal training). Set OVERFIT_N=1 for a single sample if you accept BN quirks.
#
# Env (optional):
#   PYTHON, UV_DOC_ROOT, LOGDIR, OUT_DIR, DEVICE, NUM_WORKERS
#   OVERFIT_N (default 8), BATCH_SIZE (default 8)
#   SKIP_PREPROCESS_CHECK=1, SKIP_TRAIN=1, CKPT=/path/to/ep_*_best_model.pkl
#   N_EPOCHS (default 10), N_EPOCHS_DECAY (default 10)
set -euo pipefail

OFFICIAL_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
UV_ROOT="$(cd "${OFFICIAL_ROOT}/.." && pwd)"
PY="${PYTHON:-python3}"
UV="${UV_DOC_ROOT:-${UV_ROOT}/UVDoc_final}"
LOGDIR="${LOGDIR:-${UV_ROOT}/log_official_overfit_train_infer}"
OUT_DIR="${OUT_DIR:-${LOGDIR}/verify_infer}"
DEVICE="${DEVICE:-cuda:0}"
NUM_WORKERS="${NUM_WORKERS:-4}"
OVERFIT_N="${OVERFIT_N:-8}"
BATCH_SIZE="${BATCH_SIZE:-8}"
N_EPOCHS="${N_EPOCHS:-10}"
N_EPOCHS_DECAY="${N_EPOCHS_DECAY:-10}"

cd "${OFFICIAL_ROOT}"

if [[ ! -d "${UV}/img" ]]; then
  echo "ERROR: UVDoc data not found: ${UV}/img" >&2
  exit 1
fi

echo "== UVDoc root: ${UV}"
echo "== Log dir:    ${LOGDIR}"
echo "== Overfit N:  ${OVERFIT_N}  batch: ${BATCH_SIZE}"

if [[ "${SKIP_PREPROCESS_CHECK:-0}" != "1" ]]; then
  echo "== (1) Preprocess: train vs verify_ckpt dataset tensors"
  "${PY}" verify_uvdoc_train_infer_preprocess.py \
    --data_path_UVDoc "${UV}" \
    --overfit_n "${OVERFIT_N}" \
    --check_dataloader \
    --batch_size "${BATCH_SIZE}" \
    --num_workers 0
fi

if [[ "${SKIP_TRAIN:-0}" != "1" ]]; then
  echo "== (2) Train (official defaults: lr=2e-4, alpha=beta=5, gamma=1, ep_gamma_start=10)"
  mkdir -p "${LOGDIR}"
  "${PY}" train.py \
    --data_to_use uvdoc \
    --data_path_UVDoc "${UV}" \
    --overfit_n "${OVERFIT_N}" \
    --batch_size "${BATCH_SIZE}" \
    --n_epochs "${N_EPOCHS}" \
    --n_epochs_decay "${N_EPOCHS_DECAY}" \
    --lr 0.0002 \
    --alpha_w 5.0 \
    --beta_w 5.0 \
    --gamma_w 1.0 \
    --ep_gamma_start 10 \
    --num_workers "${NUM_WORKERS}" \
    --device "${DEVICE}" \
    --log_eval_mse_train \
    --logdir "${LOGDIR}"
fi

if [[ -n "${CKPT:-}" ]]; then
  :
else
  EXP_DIR="$(ls -td "${LOGDIR}"/params* 2>/dev/null | head -1 || true)"
  if [[ -z "${EXP_DIR}" ]]; then
    echo "ERROR: No params* under ${LOGDIR}; set CKPT=... or run training." >&2
    exit 1
  fi
  CKPT="$(ls -t "${EXP_DIR}"/ep_*_best_model.pkl 2>/dev/null | head -1 || true)"
  if [[ -z "${CKPT}" ]]; then
    echo "ERROR: No ep_*_best_model.pkl under ${EXP_DIR}" >&2
    exit 1
  fi
fi

echo "== (3) Infer (same UVDocDataset kwargs as train val): ${CKPT}"
rm -rf "${OUT_DIR}"
mkdir -p "${OUT_DIR}"
"${PY}" verify_ckpt_val_pipeline.py \
  --ckpt "${CKPT}" \
  --data_path_UVDoc "${UV}" \
  --overfit_n "${OVERFIT_N}" \
  --out_dir "${OUT_DIR}" \
  --max_save_images "${OVERFIT_N}" \
  --device "${DEVICE}"

echo "== Done"
echo "  Metrics: ${OUT_DIR}/metrics.txt"
echo "  Images:  ${OUT_DIR}/*.png"
echo "Compare mean_mse in metrics.txt to the train log Val MSE and train_mse_eval (eval mode)."
cat "${OUT_DIR}/metrics.txt"
UVDoc_official/train.py
ADDED
|
@@ -0,0 +1,552 @@
| 1 |
+
import argparse
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
except ImportError:
|
| 11 |
+
tqdm = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import data_UVDoc
|
| 15 |
+
import model
|
| 16 |
+
import utils
|
| 17 |
+
from data_mixDataset import mixDataset
|
| 18 |
+
|
| 19 |
+
train_mse = 0.0
|
| 20 |
+
losscount = 0
|
| 21 |
+
gamma_w = 0.0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def setup_data(args):
|
| 25 |
+
"""
|
| 26 |
+
Returns train and validation dataloader.
|
| 27 |
+
"""
|
| 28 |
+
UVDoc = data_UVDoc.UVDocDataset
|
| 29 |
+
traindata = "train"
|
| 30 |
+
valdata = "val"
|
| 31 |
+
|
| 32 |
+
if args.data_to_use == "uvdoc":
|
| 33 |
+
if getattr(args, "overfit_n", 0) and int(args.overfit_n) > 0:
|
| 34 |
+
# Same preprocessing as val / verify_ckpt_val_pipeline.py (deterministic crop, no aug).
|
| 35 |
+
t_UVDoc_data = UVDoc(
|
| 36 |
+
data_path=args.data_path_UVDoc,
|
| 37 |
+
                appearance_augmentation=[],
                geometric_augmentations=[],
                overfit=True,
                max_samples=int(args.overfit_n),
                deterministic_crop=True,
                grid3d_stats_path=args.uvdoc_grid3d_stats,
            )
            v_UVDoc_data = UVDoc(
                data_path=args.data_path_UVDoc,
                appearance_augmentation=[],
                geometric_augmentations=[],
                overfit=True,
                max_samples=int(args.overfit_n),
                deterministic_crop=True,
                grid3d_stats_path=args.uvdoc_grid3d_stats,
            )
        else:
            t_UVDoc_data = UVDoc(
                data_path=args.data_path_UVDoc,
                appearance_augmentation=args.appearance_augmentation,
                geometric_augmentations=args.geometric_augmentationsUVDoc,
                split="train",
                val_ratio=args.uvdoc_val_ratio,
                split_seed=args.uvdoc_split_seed,
                split_mode=args.uvdoc_split_mode,
                grid3d_stats_path=args.uvdoc_grid3d_stats,
            )
            v_UVDoc_data = UVDoc(
                data_path=args.data_path_UVDoc,
                appearance_augmentation=[],
                geometric_augmentations=[],
                split="val",
                val_ratio=args.uvdoc_val_ratio,
                split_seed=args.uvdoc_split_seed,
                split_mode=args.uvdoc_split_mode,
                grid3d_stats_path=args.uvdoc_grid3d_stats,
            )
        trainloader = torch.utils.data.DataLoader(
            t_UVDoc_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True
        )
        valloader = torch.utils.data.DataLoader(
            v_UVDoc_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, pin_memory=True
        )
        return trainloader, valloader

    import data_doc3D

    doc3D = data_doc3D.doc3DDataset

    # Training data
    t_doc3D_data = doc3D(
        data_path=args.data_path_doc3D,
        split=traindata,
        appearance_augmentation=args.appearance_augmentation,
    )
    t_UVDoc_data = UVDoc(
        data_path=args.data_path_UVDoc,
        appearance_augmentation=args.appearance_augmentation,
        geometric_augmentations=args.geometric_augmentationsUVDoc,
        grid3d_stats_path=args.uvdoc_grid3d_stats,
    )
    t_mix_data = mixDataset(t_doc3D_data, t_UVDoc_data)
    if args.data_to_use == "both":
        trainloader = torch.utils.data.DataLoader(
            t_mix_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True
        )
    elif args.data_to_use == "doc3d":
        trainloader = torch.utils.data.DataLoader(
            t_doc3D_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True
        )
    else:
        raise ValueError(f"data_to_use should be either doc3d, both, or uvdoc, provided {args.data_to_use}.")

    # Validation data (doc3D only); matches upstream UVDoc repo
    v_doc3D_data = doc3D(data_path=args.data_path_doc3D, split=valdata, appearance_augmentation=[])
    valloader = torch.utils.data.DataLoader(
        v_doc3D_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, pin_memory=True
    )

    return trainloader, valloader

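The `--uvdoc_split_mode` choice above distinguishes a sample-level split from a geometry-level split (where no `geom_name` may contribute images to both train and val). A hypothetical, self-contained sketch of the two policies; the helper and field names are illustrative, not the actual `data_UVDoc` implementation:

```python
import random

# Toy samples: 20 images spread over 4 geometries
samples = [{"id": i, "geom_name": f"geom_{i % 4}"} for i in range(20)]

def split(samples, val_ratio=0.25, seed=42, mode="sample"):
    rng = random.Random(seed)  # seeded RNG makes the split reproducible
    if mode == "sample":
        # Hold out individual images at random
        ids = [s["id"] for s in samples]
        rng.shuffle(ids)
        val_ids = set(ids[: int(len(ids) * val_ratio)])
        return ([s for s in samples if s["id"] not in val_ids],
                [s for s in samples if s["id"] in val_ids])
    # mode == "geom": hold out whole geometries, so no geometry
    # contributes images to both splits
    geoms = sorted({s["geom_name"] for s in samples})
    rng.shuffle(geoms)
    val_geoms = set(geoms[: max(1, int(len(geoms) * val_ratio))])
    return ([s for s in samples if s["geom_name"] not in val_geoms],
            [s for s in samples if s["geom_name"] in val_geoms])

train, val = split(samples, mode="geom")
print(len(train), len(val))  # 15 5: one of the four geometries is held out
```

The geometry-level mode is the stricter check against leakage, since near-duplicate views of the same deformed sheet never straddle the split.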
def get_scheduler(optimizer, args, epoch_start):
    """Return a learning rate scheduler.

    Parameters:
        optimizer   -- the optimizer of the network
        args        -- stores all the experiment flags
        epoch_start -- the epoch number we started/continued from

    We keep the same learning rate for the first <args.n_epochs> epochs
    and linearly decay the rate to zero over the next <args.n_epochs_decay> epochs.
    """

    def lambda_rule(epoch):
        lr_l = 1.0 - max(0, epoch + epoch_start - args.n_epochs) / float(args.n_epochs_decay + 1)
        return lr_l

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    return scheduler


def update_learning_rate(scheduler, optimizer):
    """Update learning rates; called at the end of every epoch."""
    old_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()
    lr = optimizer.param_groups[0]["lr"]
    print("learning rate update from %.7f -> %.7f" % (old_lr, lr))
    return lr


def write_log_file(log_file_name, loss, epoch, lrate, phase):
    with open(log_file_name, "a") as f:
        f.write("\n{} LRate: {} Epoch: {} MSE: {:.5f} ".format(phase, lrate, epoch, loss))


def write_train_loss_detail_log(log_file_name, epoch, lrate, avg_net, avg_g2, avg_g3, avg_rec, gamma_w):
    """Append per-epoch mean L1 losses (same definition as netLoss) to the experiment log."""
    with open(log_file_name, "a") as f:
        f.write(
            "\nTrainLoss LRate: {} Epoch: {} net: {:.5f} L1_g2d: {:.5f} L1_g3d: {:.5f} L1_recon: {:.5f} gamma_w: {:.5f} ".format(
                lrate, epoch, avg_net, avg_g2, avg_g3, avg_rec, gamma_w
            )
        )

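The schedule that `lambda_rule` produces can be checked without touching torch: with the defaults used here (`n_epochs=10`, `n_epochs_decay=10`, fresh start), the multiplicative factor stays at 1.0 through epoch `n_epochs` and then decays linearly, reaching exactly 0 one step after the last training epoch. A pure-Python sketch:

```python
# Assumed default values; epoch_start = 0 for a fresh (non-resumed) run
n_epochs, n_epochs_decay, epoch_start = 10, 10, 0

def lambda_rule(epoch):
    # Multiplicative LR factor: 1.0 for the first n_epochs epochs,
    # then linear decay over n_epochs_decay + 1 steps
    return 1.0 - max(0, epoch + epoch_start - n_epochs) / float(n_epochs_decay + 1)

factors = [round(lambda_rule(e), 4) for e in (0, 9, 10, 15, 20, 21)]
print(factors)  # [1.0, 1.0, 1.0, 0.5455, 0.0909, 0.0]
```

Note the factor at the final training epoch (20) is still 1/11, not 0; the rate only hits zero on the step after the loop ends, which is why the denominator is `n_epochs_decay + 1`.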
def main_worker(args):
    # setup training data
    trainloader, valloader = setup_data(args)

    device = torch.device(args.device)
    UVDocnet = model.UVDocnet(num_filter=32, kernel_size=5)
    UVDocnet.to(device)

    # define loss functions
    criterionL1 = torch.nn.L1Loss()
    criterionMSE = torch.nn.MSELoss()

    # initialize optimizers
    optimizer = torch.optim.Adam(UVDocnet.parameters(), lr=args.lr, betas=(0.9, 0.999))

    global gamma_w
    epoch_start = 0

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            UVDocnet.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint["epoch"]))
            epoch_start = checkpoint["epoch"]
            if epoch_start >= args.ep_gamma_start:
                gamma_w = args.gamma_w
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # initialize learning rate schedulers
    scheduler = get_scheduler(optimizer, args, epoch_start)

    # Log file:
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    experiment_name = (
        "params"
        + str(args.batch_size)
        + "_lr="
        + str(args.lr)
        + "_nepochs"
        + str(args.n_epochs)
        + "_nepochsdecay"
        + str(args.n_epochs_decay)
        + "_alpha"
        + str(args.alpha_w)
        + "_beta"
        + str(args.beta_w)
        + "_gamma="
        + str(args.gamma_w)
        + "_gammastartep"
        + str(args.ep_gamma_start)
        + "_data"
        + args.data_to_use
    )
    if args.resume:
        experiment_name = "RESUME" + experiment_name

    log_file_name = os.path.join(args.logdir, experiment_name + ".txt")
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, "a")
    else:
        log_file = open(log_file_name, "w+")

    log_file.write("\n--------------- " + experiment_name + " ---------------\n")
    log_file.close()

    exp_log_dir = os.path.join(args.logdir, experiment_name, "")
    if not os.path.exists(exp_log_dir):
        os.makedirs(exp_log_dir)

    global losscount
    global train_mse

    # Run training
    best_val_mse = float("inf")
    for epoch in range(epoch_start, args.n_epochs + args.n_epochs_decay + 1):
        print(f"\n----- Epoch {epoch} -----")
        if epoch >= args.ep_gamma_start:
            gamma_w = args.gamma_w
            print("epoch ", epoch, "gamma_w is now", gamma_w)

        train_mse = 0.0
        losscount = 0
        sum_l1_g2 = 0.0
        sum_l1_g3 = 0.0
        sum_l1_rec = 0.0
        sum_net = 0.0

        # Train
        UVDocnet.train()

        train_iter = trainloader
        if tqdm is not None and not args.no_tqdm:
            train_iter = tqdm(
                trainloader,
                desc=f"train epoch {epoch}",
                dynamic_ncols=True,
                mininterval=2.0,
                file=sys.stdout,
            )
        for batch in train_iter:
            if args.data_to_use == "both":
                imgs_doc3D_, imgs_unwarped_doc3D_, grid2D_doc3D_, grid3D_doc3D_ = batch[0]
                imgs_UVDoc_, imgs_unwarped_UVDoc_, grid2D_UVDoc_, grid3D_UVDoc_ = batch[1]
            elif args.data_to_use == "uvdoc":
                imgs_UVDoc_, imgs_unwarped_UVDoc_, grid2D_UVDoc_, grid3D_UVDoc_ = batch
            elif args.data_to_use == "doc3d":
                imgs_doc3D_, imgs_unwarped_doc3D_, grid2D_doc3D_, grid3D_doc3D_ = batch

            # Train Doc3D step (official default; skipped for uvdoc-only)
            if args.data_to_use in ("both", "doc3d"):
                imgs_doc3D = imgs_doc3D_.to(device, non_blocking=True)
                unwarped_GT_doc3D = imgs_unwarped_doc3D_.to(device, non_blocking=True)
                grid2D_GT_doc3D = grid2D_doc3D_.to(device, non_blocking=True)
                grid3D_GT_doc3D = grid3D_doc3D_.to(device, non_blocking=True)

                grid2D_pred_doc3D, grid3D_pred_doc3D = UVDocnet(imgs_doc3D)
                unwarped_pred_doc3D = utils.bilinear_unwarping(imgs_doc3D, grid2D_pred_doc3D, utils.IMG_SIZE)

                optimizer.zero_grad(set_to_none=True)

                recon_loss = criterionL1(unwarped_pred_doc3D, unwarped_GT_doc3D)
                loss_grid2D = criterionL1(grid2D_pred_doc3D, grid2D_GT_doc3D)
                loss_grid3D = criterionL1(grid3D_pred_doc3D, grid3D_GT_doc3D)

                netLoss = args.alpha_w * loss_grid2D + args.beta_w * loss_grid3D + gamma_w * recon_loss
                sum_l1_g2 += float(loss_grid2D.detach())
                sum_l1_g3 += float(loss_grid3D.detach())
                sum_l1_rec += float(recon_loss.detach())
                sum_net += float(netLoss.detach())
                netLoss.backward()
                optimizer.step()

                tmp_mse = criterionMSE(unwarped_pred_doc3D, unwarped_GT_doc3D)
                train_mse += float(tmp_mse)
                losscount += 1

            # Train UVDoc step
            if args.data_to_use in ("both", "uvdoc"):
                imgs_UVDoc = imgs_UVDoc_.to(device, non_blocking=True)
                unwarped_GT_UVDoc = imgs_unwarped_UVDoc_.to(device, non_blocking=True)
                grid2D_GT_UVDoc = grid2D_UVDoc_.to(device, non_blocking=True)
                grid3D_GT_UVDoc = grid3D_UVDoc_.to(device, non_blocking=True)

                grid2D_pred_UVDoc, grid3D_pred_UVDoc = UVDocnet(imgs_UVDoc)
                unwarped_pred_UVDoc = utils.bilinear_unwarping(imgs_UVDoc, grid2D_pred_UVDoc, utils.IMG_SIZE)

                optimizer.zero_grad(set_to_none=True)

                recon_loss = criterionL1(unwarped_pred_UVDoc, unwarped_GT_UVDoc)
                loss_grid2D = criterionL1(grid2D_pred_UVDoc, grid2D_GT_UVDoc)
                loss_grid3D = criterionL1(grid3D_pred_UVDoc, grid3D_GT_UVDoc)

                netLoss = args.alpha_w * loss_grid2D + args.beta_w * loss_grid3D + gamma_w * recon_loss
                sum_l1_g2 += float(loss_grid2D.detach())
                sum_l1_g3 += float(loss_grid3D.detach())
                sum_l1_rec += float(recon_loss.detach())
                sum_net += float(netLoss.detach())
                netLoss.backward()
                optimizer.step()

                tmp_mse = criterionMSE(unwarped_pred_UVDoc, unwarped_GT_UVDoc)
                train_mse += float(tmp_mse)
                losscount += 1
            gc.collect()

        train_mse = train_mse / max(1, losscount)
        curr_lr = update_learning_rate(scheduler, optimizer)
        write_log_file(log_file_name, train_mse, epoch + 1, curr_lr, "Train")
        if losscount > 0:
            avg_g2 = sum_l1_g2 / losscount
            avg_g3 = sum_l1_g3 / losscount
            avg_rec = sum_l1_rec / losscount
            avg_net = sum_net / losscount
            write_train_loss_detail_log(
                log_file_name, epoch + 1, curr_lr, avg_net, avg_g2, avg_g3, avg_rec, gamma_w
            )
            print(
                f"Epoch {epoch} train L1 | net={avg_net:.5f} g2d={avg_g2:.5f} g3d={avg_g3:.5f} recon={avg_rec:.5f} gamma_w={gamma_w}",
                flush=True,
            )

        # Evaluation
        train_mse_eval = None
        UVDocnet.eval()

        with torch.no_grad():
            mse_loss_val = 0.0
            val_iter = valloader
            if tqdm is not None and not args.no_tqdm:
                val_iter = tqdm(
                    valloader,
                    desc=f"val epoch {epoch}",
                    dynamic_ncols=True,
                    mininterval=2.0,
                    file=sys.stdout,
                    leave=False,
                )
            for imgs_val_, imgs_unwarped_val_, _, _ in val_iter:
                imgs_val = imgs_val_.to(device)
                unwarped_GT_val = imgs_unwarped_val_.to(device)

                grid2D_pred_val, grid3D_pred_val = UVDocnet(imgs_val)
                unwarped_pred_val = utils.bilinear_unwarping(imgs_val, grid2D_pred_val, utils.IMG_SIZE)

                loss_img_val = criterionMSE(unwarped_pred_val, unwarped_GT_val)
                mse_loss_val += float(loss_img_val)

            denom = max(1, len(valloader))
            val_mse = mse_loss_val / denom
            write_log_file(log_file_name, val_mse, epoch + 1, curr_lr, "Val")

            if getattr(args, "log_eval_mse_train", False) and args.data_to_use == "uvdoc":
                mse_tr = 0.0
                denom_tr = max(1, len(trainloader))
                for imgs_tr_, imgs_uw_tr_, _, _ in trainloader:
                    imgs_tr = imgs_tr_.to(device)
                    uw_gt_tr = imgs_uw_tr_.to(device)
                    g2_tr, _ = UVDocnet(imgs_tr)
                    pred_tr = utils.bilinear_unwarping(imgs_tr, g2_tr, utils.IMG_SIZE)
                    mse_tr += float(criterionMSE(pred_tr, uw_gt_tr))
                train_mse_eval = mse_tr / denom_tr
                write_log_file(log_file_name, train_mse_eval, epoch + 1, curr_lr, "TrainEval")

        if train_mse_eval is not None:
            print(
                f"Epoch {epoch} summary | train_mse={train_mse:.5f} val_mse={val_mse:.5f} "
                f"train_mse_eval={train_mse_eval:.5f} lr={curr_lr:.7f}",
                flush=True,
            )
        else:
            print(
                f"Epoch {epoch} summary | train_mse={train_mse:.5f} val_mse={val_mse:.5f} lr={curr_lr:.7f}",
                flush=True,
            )

        # save best models
        if val_mse < best_val_mse or epoch == args.n_epochs + args.n_epochs_decay:
            best_val_mse = val_mse
            state = {
                "epoch": epoch + 1,
                "model_state": UVDocnet.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }
            model_path = exp_log_dir + f"ep_{epoch + 1}_{val_mse:.5f}_{train_mse:.5f}_best_model.pkl"
            torch.save(state, model_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparams")

    parser.add_argument(
        "--data_path_doc3D", nargs="?", type=str, default="./data/doc3D/", help="Data path to load Doc3D data."
    )
    parser.add_argument(
        "--data_path_UVDoc", nargs="?", type=str, default="./data/UVDoc/", help="Data path to load UVDoc data."
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Torch device, e.g. cuda:0 or cpu.")
    parser.add_argument(
        "--uvdoc_val_ratio",
        type=float,
        default=0.05,
        help="Hold-out ratio for the UVDoc train/val split when data_to_use=uvdoc.",
    )
    parser.add_argument(
        "--uvdoc_split_seed",
        type=int,
        default=42,
        help="Random seed for the UVDoc train/val split when data_to_use=uvdoc.",
    )
    parser.add_argument(
        "--uvdoc_split_mode",
        type=str,
        default="sample",
        choices=["sample", "geom"],
        help="UVDoc train/val split: random by image id (sample) or by geom_name so no geometry appears in both splits (geom).",
    )
    parser.add_argument(
        "--uvdoc_grid3d_stats",
        type=str,
        default=None,
        help="Optional JSON from compute_uvdoc_grid3d_stats.py; overrides the built-in grid3d min/max normalization.",
    )
    parser.add_argument(
        "--overfit_n",
        type=int,
        default=0,
        help="If >0 (with data_to_use=uvdoc), use only the first N sorted samples for BOTH train and val to sanity-check fitting.",
    )
    parser.add_argument(
        "--data_to_use",
        type=str,
        default="both",
        choices=["both", "doc3d", "uvdoc"],
        help="Dataset: both (Doc3D+UVDoc, official), doc3d only, or uvdoc (UVDoc only; extension for local training without Doc3D).",
    )
    parser.add_argument("--batch_size", nargs="?", type=int, default=8, help="Batch size.")
    parser.add_argument(
        "--n_epochs",
        nargs="?",
        type=int,
        default=10,
        help="Number of epochs with the initial (constant) learning rate.",
    )
    parser.add_argument(
        "--n_epochs_decay",
        nargs="?",
        type=int,
        default=10,
        help="Number of epochs to linearly decay the learning rate to zero.",
    )
    parser.add_argument("--lr", nargs="?", type=float, default=0.0002, help="Initial learning rate.")
    parser.add_argument("--alpha_w", nargs="?", type=float, default=5.0, help="Weight for the 2D grid L1 loss.")
    parser.add_argument("--beta_w", nargs="?", type=float, default=5.0, help="Weight for the 3D grid L1 loss.")
    parser.add_argument(
        "--gamma_w", nargs="?", type=float, default=1.0, help="Weight for the image reconstruction loss."
    )
    parser.add_argument(
        "--ep_gamma_start",
        nargs="?",
        type=int,
        default=10,
        help="Epoch from which to start using the image reconstruction loss.",
    )
    parser.add_argument(
        "--resume",
        nargs="?",
        type=str,
        default=None,
        help="Path to a previously saved model to restart from.",
    )
    parser.add_argument("--logdir", nargs="?", type=str, default="./log/default", help="Path to store the logs.")
    parser.add_argument(
        "-a",
        "--appearance_augmentation",
        nargs="*",
        type=str,
        default=["visual", "noise", "color"],
        choices=["shadow", "blur", "visual", "noise", "color"],
        help="Appearance augmentations to use.",
    )
    parser.add_argument(
        "-gUVDoc",
        "--geometric_augmentationsUVDoc",
        nargs="*",
        type=str,
        default=["rotate"],
        choices=["rotate", "flip", "perspective"],
        help="Geometric augmentations to use for the UVDoc dataset.",
    )
    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers to use for the dataloaders.")
    parser.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars (use plain loops plus epoch summary prints).",
    )
    parser.add_argument(
        "--log_eval_mse_train",
        action="store_true",
        help="After validation, also compute MSE(unwarp) in eval() mode on the UVDoc train loader; matches verify_ckpt_val_pipeline.py (not the train-mode train_mse).",
    )

    args = parser.parse_args()
    main_worker(args)
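The `experiment_name` string built in `main_worker` doubles as the log filename stem, which is why log directories like `params8_lr=0.0002_nepochs50_...` appear in this commit. A quick sketch with the default hyperparameters (values assumed from the argparse defaults above):

```python
# Default hyperparameter values from the argparse block
batch_size, lr, n_epochs, n_epochs_decay = 8, 0.0002, 10, 10
alpha_w, beta_w, gamma_w, ep_gamma_start, data_to_use = 5.0, 5.0, 1.0, 10, "both"

# Same concatenation scheme as main_worker
experiment_name = (
    "params" + str(batch_size) + "_lr=" + str(lr) + "_nepochs" + str(n_epochs)
    + "_nepochsdecay" + str(n_epochs_decay) + "_alpha" + str(alpha_w)
    + "_beta" + str(beta_w) + "_gamma=" + str(gamma_w)
    + "_gammastartep" + str(ep_gamma_start) + "_data" + data_to_use
)
print(experiment_name)
# params8_lr=0.0002_nepochs10_nepochsdecay10_alpha5.0_beta5.0_gamma=1.0_gammastartep10_databoth
```

Because the name encodes every loss weight and epoch count, two runs with different hyperparameters never collide in the same log file.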
UVDoc_official/utils.py
ADDED
import os

import torch
import torch.nn.functional as F

from model import UVDocnet

IMG_SIZE = [488, 712]
GRID_SIZE = [45, 31]


def load_model(ckpt_path):
    """
    Load the UVDocnet model from a checkpoint.
    """
    model = UVDocnet(num_filter=32, kernel_size=5)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt["model_state"])
    return model


def get_version():
    """
    Return the versions of the various packages used for evaluation.
    """
    import pytesseract

    return {
        "tesseract": str(pytesseract.get_tesseract_version()),
        "pytesseract": os.popen("pip list | grep pytesseract").read().split()[-1],
        "Levenshtein": os.popen("pip list | grep Levenshtein").read().split()[-1],
        "jiwer": os.popen("pip list | grep jiwer").read().split()[-1],
        "matlabengineforpython": os.popen("pip list | grep matlab").read().split()[-1],
    }


def bilinear_unwarping(warped_img, point_positions, img_size):
    """
    Utility function that unwarps an image.
    Unwarp warped_img based on the 2D grid point_positions with a size img_size.
    Args:
        warped_img: torch.Tensor of shape BxCxHxW (dtype float)
        point_positions: torch.Tensor of shape Bx2xGhxGw (dtype float)
        img_size: tuple of int [w, h]
    """
    upsampled_grid = F.interpolate(
        point_positions, size=(img_size[1], img_size[0]), mode="bilinear", align_corners=True
    )
    unwarped_img = F.grid_sample(warped_img, upsampled_grid.transpose(1, 2).transpose(2, 3), align_corners=True)

    return unwarped_img


def bilinear_unwarping_from_numpy(warped_img, point_positions, img_size):
    """
    Utility function that unwarps an image.
    Unwarp warped_img based on the 2D grid point_positions with a size img_size.
    Accepts numpy arrays as input.
    """
    warped_img = torch.unsqueeze(torch.from_numpy(warped_img.transpose(2, 0, 1)).float(), dim=0)
    point_positions = torch.unsqueeze(torch.from_numpy(point_positions.transpose(2, 0, 1)).float(), dim=0)

    unwarped_img = bilinear_unwarping(warped_img, point_positions, img_size)

    unwarped_img = unwarped_img[0].numpy().transpose(1, 2, 0)
    return unwarped_img
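`bilinear_unwarping` leans on `F.grid_sample` with `align_corners=True`, where a normalized coordinate in [-1, 1] maps to pixel coordinate `(u + 1) / 2 * (size - 1)` and the value is bilinearly interpolated from the four neighbors. A small numpy sketch of that sampling rule (my own illustration of the convention, not repo code):

```python
import numpy as np

def bilinear_sample(img, x, y):
    """Sample a 2D array at continuous pixel coords (x, y) with bilinear weights."""
    h, w = img.shape
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    dx, dy = x - x0, y - y0
    return (img[y0, x0] * (1 - dx) * (1 - dy) + img[y0, x1] * dx * (1 - dy)
            + img[y1, x0] * (1 - dx) * dy + img[y1, x1] * dx * dy)

# With align_corners=True, normalized u in [-1, 1] maps to pixel (u + 1) / 2 * (size - 1)
img = np.arange(16, dtype=float).reshape(4, 4)
u, v = 0.0, 0.0  # center of the normalized grid
x = (u + 1) / 2 * (img.shape[1] - 1)  # -> 1.5
y = (v + 1) / 2 * (img.shape[0] - 1)  # -> 1.5
print(bilinear_sample(img, x, y))  # 7.5, the average of the four central pixels
```

This also explains the `.transpose(1, 2).transpose(2, 3)` above: `grid_sample` expects the sampling grid as BxHxWx2 with the last axis holding (x, y) in that normalized range.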
UVDoc_official/uvdocBenchmark_eval.py
ADDED
import argparse
import json
import multiprocessing as mp
import os

from utils import get_version

N_LINES = 25


def visual_metrics_process(queue, uvdoc_path, preds_path, verbose):
    """
    Subprocess that computes the visual metrics (MS-SSIM, LD, and AD) via a Matlab script.
    """
    import matlab.engine

    eng = matlab.engine.start_matlab()
    eng.cd(r"./eval/eval_code/", nargout=0)

    mean_ms, mean_ad = eng.evalScriptUVDoc(uvdoc_path, preds_path, verbose, nargout=2)
    queue.put(dict(ms=mean_ms, ad=mean_ad))


def ocr_process(queue, uvdoc_path, preds_path):
    """
    Subprocess that computes the OCR metrics (CER and ED).
    """
    from eval.ocr_eval.ocr_eval import OCR_eval_UVDoc

    CERmean, EDmean, OCR_dict_results = OCR_eval_UVDoc(uvdoc_path, preds_path)
    with open(os.path.join(preds_path, "ocr_res.json"), "w") as f:
        json.dump(OCR_dict_results, f)
    queue.put(dict(cer=CERmean, ed=EDmean))


def new_line_metric_process(queue, uvdoc_path, preds_path, n_lines):
    """
    Subprocess that computes the new line metrics on the UVDoc benchmark.
    """
    from uvdocBenchmark_metric import compute_line_metric

    hor_metric, ver_metric = compute_line_metric(uvdoc_path, preds_path, n_lines)
    queue.put(dict(hor_line=hor_metric, ver_line=ver_metric))


def compute_metrics(uvdoc_path, pred_path, pred_type, verbose=False):
    """
    Compute and save all metrics.
    """
    if not pred_path.endswith("/"):
        pred_path += "/"
    q = mp.Queue()

    # Create process to compute MS-SSIM, LD, AD
    p1 = mp.Process(
        target=visual_metrics_process,
        args=(q, os.path.join(uvdoc_path, "texture_sample"), os.path.join(pred_path, pred_type), verbose),
    )
    p1.start()

    # Create process to compute new line metrics
    p2 = mp.Process(
        target=new_line_metric_process,
        args=(q, uvdoc_path, os.path.join(pred_path, "bm"), N_LINES),
    )
    p2.start()

    # Create process to compute OCR metrics
    p3 = mp.Process(
        target=ocr_process, args=(q, os.path.join(uvdoc_path, "texture_sample"), os.path.join(pred_path, pred_type))
    )
    p3.start()

    p1.join()
    p2.join()
    p3.join()

    # Get results
    res = {}
    for _ in range(q.qsize()):
        ret = q.get()
        for k, v in ret.items():
            res[k] = v

    # Print and save results
    print("--- Results ---")
    print(f" Mean MS-SSIM : {res['ms']}")
    print(f" Mean AD : {res['ad']}")
    print(f" Mean CER : {res['cer']}")
    print(f" Mean ED : {res['ed']}")
    print(f" Hor Line : {res['hor_line']}")
    print(f" Ver Line : {res['ver_line']}")

    with open(os.path.join(pred_path, pred_type, "resUVDoc.txt"), "w") as f:
        f.write(f"Mean MS-SSIM : {res['ms']}\n")
        f.write(f"Mean AD : {res['ad']}\n")
        f.write(f"Mean CER : {res['cer']}\n")
        f.write(f"Mean ED : {res['ed']}\n")
        f.write(f"Hor Line : {res['hor_line']}\n")
        f.write(f"Ver Line : {res['ver_line']}\n")

        f.write("\n--- Module Version ---\n")
        for module, version in get_version().items():
            f.write(f"{module:25s}: {version}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--uvdoc-path", type=str, default="./data/UVDoc_benchmark/", help="Path to the UVDoc benchmark dataset."
    )
    parser.add_argument("--pred-path", type=str, help="Path to the UVDoc benchmark predictions. Needs to be absolute.")
    parser.add_argument(
        "--pred-type",
        type=str,
        default="uwp_texture",
        choices=["uwp_texture", "uwp_img"],
        help="Which type of prediction to compare: either the unwarped textures or the unwarped lit images.",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()

    compute_metrics(
        uvdoc_path=os.path.abspath(args.uvdoc_path),
        pred_path=os.path.abspath(args.pred_path),
        pred_type=args.pred_type,
        verbose=args.verbose,
    )
UVDoc_official/uvdocBenchmark_metric.py
ADDED
import json
import os
from os.path import join as pjoin

import hdf5storage as h5
import numpy as np
import torch
import torch.nn.functional as F
from skimage.morphology import binary_erosion
from tqdm import tqdm

from utils import bilinear_unwarping_from_numpy

WIDTH = 1000
HEIGHT = 1000


def create_vertical_stripe_texture(width, height, stripe_width=1, position=0):
    """
    Create an image with a vertical stripe.
    """
    im = np.ones((height, width, 3), dtype=np.uint8) * 255
    im[:, position : position + stripe_width] = 0
    return im


def create_horizontal_stripe_texture(width, height, stripe_width=1, position=0):
    """
    Create an image with a horizontal stripe.
    """
    im = np.ones((height, width, 3), dtype=np.uint8) * 255
    im[position : position + stripe_width, :] = 0
    return im


def warp_texture(texture, uvmap):
    """
    Warp an input texture based on the provided uvmap.
    """
    # Warp the texture based on the uv map
    torch_texture_unwarp = torch.from_numpy(np.expand_dims(texture.transpose(2, 0, 1), axis=0)).float()
    uvmap_torch = torch.from_numpy(np.expand_dims(uvmap * 2 - 1, axis=0)).float()
    warped_texture = F.grid_sample(torch_texture_unwarp, uvmap_torch, align_corners=False)
    warped_texture = np.clip(warped_texture[0].numpy().transpose(1, 2, 0), 0, 255) / 255

    # Postprocessing to get nicer results
    grey = np.all(warped_texture == 0.5, axis=-1)
    warped_texture[grey] = np.nan
    mask = 1 - np.all(np.isnan(warped_texture), axis=-1).astype(int)
    mask_small = binary_erosion(mask).astype(int)
    mask_small = np.expand_dims(mask_small, axis=-1)
    warped_texture[np.repeat(~mask_small.astype(bool), 3, axis=-1)] = 1
    warped_texture = (warped_texture * 255).astype(np.uint8)

    return warped_texture


def compute_metric_single_line(uvmap, bm, pos, direction="horizontal"):
    """
    Compute the line metric for a single line.
    Args:
        uvmap: uvmap of the document, shape (height, width, 2)
        bm: predicted backward mapping, shape (height, width, 2)
        pos: position of the line to compute the metric at
        direction: direction of the line (horizontal or vertical)
    """
    # Create the original straight line
    if direction == "horizontal":
        stripe = create_horizontal_stripe_texture(WIDTH, HEIGHT, stripe_width=1, position=pos)
    elif direction == "vertical":
        stripe = create_vertical_stripe_texture(WIDTH, HEIGHT, stripe_width=1, position=pos)
    else:
        raise ValueError("Direction must be horizontal or vertical")

    # Warp the stripe according to the ground truth uvmap and unwarp it according to the predicted bm
    warped_stripe = warp_texture(stripe, uvmap)
    unwarped_stripe = bilinear_unwarping_from_numpy(warped_stripe.astype(float) / 255.0, bm, (WIDTH, HEIGHT))

    # Binarize the result
    THRESH = 0.5
    unwarped_stripe = unwarped_stripe[:, :, 0]
    unwarped_stripe[unwarped_stripe < THRESH] = 0
    unwarped_stripe[unwarped_stripe >= THRESH] = 1

    # Find the black pixels
    xs, ys = np.where(unwarped_stripe == 0)
    if len(xs) == 0 or len(ys) == 0:
        # No black pixels in the line: the backward mapping is very poor
        return np.nan

    # Compute the metric
    if direction == "horizontal":
        return np.std(xs)
    elif direction == "vertical":
        return np.std(ys)


def compute_sample_line_metric(uvdoc_path, pred_path, sample, n_lines):
    """
    Compute all line metrics for a given sample.
    """
    # Load ground truth UV map
    metadata_path = pjoin(uvdoc_path, "metadata_sample", f"{sample}.json")
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    uvmap_path = pjoin(uvdoc_path, "uvmap", f"{metadata['geom_name']}.mat")
    uvmap = h5.loadmat(uvmap_path)["uv"]
uvmap = h5.loadmat(uvmap_path)["uv"]
|
| 108 |
+
|
| 109 |
+
# Load predicted backward mapping
|
| 110 |
+
bm_path = pjoin(pred_path, f"{sample}.mat")
|
| 111 |
+
bm = h5.loadmat(bm_path)["bm"]
|
| 112 |
+
|
| 113 |
+
# Compute metric
|
| 114 |
+
stds_hor = []
|
| 115 |
+
stds_ver = []
|
| 116 |
+
for pos in np.linspace(50, 950, n_lines, dtype=int):
|
| 117 |
+
uvmap = h5.loadmat(uvmap_path)["uv"]
|
| 118 |
+
stds_hor.append(compute_metric_single_line(uvmap, bm, pos, direction="horizontal"))
|
| 119 |
+
stds_ver.append(compute_metric_single_line(uvmap, bm, pos, direction="vertical"))
|
| 120 |
+
|
| 121 |
+
return np.nanmean(stds_hor), np.nanmean(stds_ver)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def compute_line_metric(uvdoc_path, pred_path, n_lines=25):
|
| 125 |
+
"""
|
| 126 |
+
Compute the line metric over the whole UVDoc dataset.
|
| 127 |
+
"""
|
| 128 |
+
# Find all samples
|
| 129 |
+
all_samples = sorted([x[:-4] for x in os.listdir(pjoin(uvdoc_path, "img"))])
|
| 130 |
+
|
| 131 |
+
# Compute the metric for each sample
|
| 132 |
+
lines = []
|
| 133 |
+
cols = []
|
| 134 |
+
for sample in tqdm(all_samples):
|
| 135 |
+
hor, ver = compute_sample_line_metric(uvdoc_path, pred_path, sample, n_lines)
|
| 136 |
+
lines.append(hor)
|
| 137 |
+
cols.append(ver)
|
| 138 |
+
|
| 139 |
+
# Saves all results including individual ones
|
| 140 |
+
with open(os.path.join(pred_path, "line_metric.json"), "w") as f:
|
| 141 |
+
json.dump(
|
| 142 |
+
{sample: {"hor": lines[i], "ver": cols[i]} for i, sample in enumerate(all_samples)},
|
| 143 |
+
f,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
with open(os.path.join(pred_path, "line_metric_mean.json"), "w") as f:
|
| 147 |
+
json.dump(
|
| 148 |
+
{"hor": np.mean(lines), "ver": np.mean(cols)},
|
| 149 |
+
f,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
return np.mean(lines), np.mean(cols)
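The line metric above boils down to a simple idea: warp a one-pixel straight stripe with the ground-truth UV map, unwarp it with the predicted backward map, and measure how much the recovered stripe deviates from a straight line via the standard deviation of its pixel coordinates. A minimal self-contained sketch of that idea (a toy illustration with a hypothetical `line_straightness` helper, not the full benchmark pipeline):

```python
import numpy as np

def line_straightness(binary_img: np.ndarray) -> float:
    # std of black-pixel row indices: 0 for a perfectly straight
    # horizontal line, larger for residual distortion
    xs, ys = np.where(binary_img == 0)
    return float(np.std(xs))

# Perfectly straight horizontal line at row 50
straight = np.ones((100, 100), dtype=np.uint8)
straight[50, :] = 0

# Sinusoidally distorted line, as left behind by imperfect unwarping
wavy = np.ones((100, 100), dtype=np.uint8)
rows = (50 + 3 * np.sin(np.linspace(0, 2 * np.pi, 100))).astype(int)
wavy[rows, np.arange(100)] = 0

print(line_straightness(straight))  # 0.0
print(line_straightness(wavy) > 0)  # True
```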
|
UVDoc_official/uvdocBenchmark_pred.py
ADDED
@@ -0,0 +1,131 @@
import argparse
import os

import cv2
import hdf5storage as h5
import numpy as np
import torch
from tqdm import tqdm

from utils import IMG_SIZE, bilinear_unwarping, load_model


class UVDocBenchmarkLoader(torch.utils.data.Dataset):
    """
    Torch dataset class for the UVDoc benchmark dataset.
    """

    def __init__(
        self,
        data_path,
        img_size=(488, 712),
    ):
        self.dataroot = data_path
        self.im_list = os.listdir(os.path.join(self.dataroot, "img"))
        self.img_size = img_size

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, index):
        im_name = self.im_list[index]
        img_path = os.path.join(self.dataroot, "img", im_name)
        img_RGB = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_RGB = torch.from_numpy(cv2.resize(img_RGB, self.img_size).transpose(2, 0, 1))
        return img_RGB, im_name


def infer_uvdoc(model, dataloader, device, save_path):
    """
    Unwarp all images in the UVDoc benchmark and save them, along with the mappings.
    """
    model.eval()

    os.makedirs(os.path.join(save_path, "uwp_img"), exist_ok=True)
    os.makedirs(os.path.join(save_path, "bm"), exist_ok=True)
    os.makedirs(os.path.join(save_path, "uwp_texture"), exist_ok=True)

    for img_RGB, im_names in tqdm(dataloader):
        # Inference
        img_RGB = img_RGB.to(device)
        point_positions2D, _ = model(img_RGB)

        # The warped image needs to be re-opened to get full resolution (it was downsampled in the data loader)
        warped = cv2.imread(os.path.join(dataloader.dataset.dataroot, "img", im_names[0]))
        warped = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB)
        warped = torch.from_numpy(warped.transpose(2, 0, 1) / 255.0).float()
        size = warped.shape[1:][::-1]

        # Unwarping
        unwarped = bilinear_unwarping(
            warped_img=torch.unsqueeze(warped, dim=0).to(device),
            point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
            img_size=tuple(size),
        )
        unwarped = (unwarped[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        unwarped_BGR = cv2.cvtColor(unwarped, cv2.COLOR_RGB2BGR)

        cv2.imwrite(
            os.path.join(save_path, "uwp_img", im_names[0].split(" ")[0].split(".")[0] + ".png"),
            unwarped_BGR,
        )

        # Unwarp and save the texture
        warp_texture = cv2.imread(os.path.join(dataloader.dataset.dataroot, "warped_textures", im_names[0]))
        warp_texture = cv2.cvtColor(warp_texture, cv2.COLOR_BGR2RGB)
        warp_texture = torch.from_numpy(warp_texture.transpose(2, 0, 1) / 255.0).float()
        size = warp_texture.shape[1:][::-1]

        unwarped_texture = bilinear_unwarping(
            warped_img=torch.unsqueeze(warp_texture, dim=0).to(device),
            point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
            img_size=tuple(size),
        )
        unwarped_texture = (unwarped_texture[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        unwarped_texture_BGR = cv2.cvtColor(unwarped_texture, cv2.COLOR_RGB2BGR)

        cv2.imwrite(
            os.path.join(save_path, "uwp_texture", im_names[0].split(" ")[0].split(".")[0] + ".png"),
            unwarped_texture_BGR,
        )

        # Save the backward map
        h5.savemat(
            os.path.join(save_path, "bm", im_names[0].split(" ")[0].split(".")[0] + ".mat"),
            {"bm": point_positions2D[0].detach().cpu().numpy().transpose(1, 2, 0)},
        )


def create_uvdoc_results(ckpt_path, uvdoc_path, img_size):
    """
    Create results for the UVDoc benchmark.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model, create dataset and save directory
    model = load_model(ckpt_path)
    model.to(device)

    dataset = UVDocBenchmarkLoader(data_path=uvdoc_path, img_size=img_size)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)

    save_path = os.path.join("/".join(ckpt_path.split("/")[:-1]), "output_uvdoc")
    os.makedirs(save_path, exist_ok=True)
    print(f" Results will be saved at {save_path}", flush=True)

    # Infer results (use the computed device rather than hardcoding "cuda:0",
    # so the script still works on CPU-only machines)
    infer_uvdoc(model, dataloader, device, save_path)
    return save_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
    )
    parser.add_argument(
        "--uvdoc-path", type=str, default="./data/UVDoc_benchmark/", help="Path to the UVDocBenchmark dataset."
    )
    args = parser.parse_args()

    create_uvdoc_results(args.ckpt_path, os.path.abspath(args.uvdoc_path), IMG_SIZE)
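The `bilinear_unwarping` call used throughout the prediction script is, conceptually, a two-step operation: bilinearly upsample the coarse predicted grid to the target resolution, then sample the warped image with `grid_sample`. A hedged sketch of that step (the function name `bilinear_unwarp` and its exact interpolation settings are assumptions; the real implementation lives in `utils.bilinear_unwarping`):

```python
import torch
import torch.nn.functional as F

def bilinear_unwarp(warped_img, grid2d, img_size):
    # warped_img: (B, 3, H, W); grid2d: (B, 2, Gh, Gw) with values in [-1, 1]
    # img_size: (W, H), matching the `size = warped.shape[1:][::-1]` convention above
    up = F.interpolate(grid2d, size=(img_size[1], img_size[0]), mode="bilinear", align_corners=True)
    # grid_sample expects the grid as (B, H, W, 2) with (x, y) in the last dim
    return F.grid_sample(warped_img, up.permute(0, 2, 3, 1), align_corners=True)

# Sanity check: an identity grid reproduces the input image
img = torch.rand(1, 3, 64, 48)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 7), torch.linspace(-1, 1, 9), indexing="ij")
grid = torch.stack([xs, ys], dim=0).unsqueeze(0)  # (1, 2, 7, 9)
out = bilinear_unwarp(img, grid, (48, 64))
print(torch.allclose(out, img, atol=1e-4))  # True
```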
|
UVDoc_official/verify_ckpt_val_pipeline.py
ADDED
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Run inference with the *same* preprocessing as train.py validation (UVDocDataset).

The official demo.py resizes the full image only; training uses a tight crop + resize.
Use this script to check train/val vs. inference consistency on UVDoc.

The mean MSE printed at the end matches train.py val when batch_size divides the val set
evenly (the mean of per-batch MSEs with the default MSELoss reduction='mean' then equals
the mean per-image MSE computed here).
"""
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn as nn

import data_UVDoc
import utils
from model import UVDocnet


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to ep_*_best_model.pkl")
    parser.add_argument("--data_path_UVDoc", type=str, required=True)
    parser.add_argument(
        "--overfit_n",
        type=int,
        default=0,
        help="If >0, same as train.py: first N sorted samples, deterministic crop.",
    )
    parser.add_argument(
        "--uvdoc_val_ratio",
        type=float,
        default=0.05,
        help="Must match train.py when using split=val (not overfit).",
    )
    parser.add_argument(
        "--uvdoc_split_seed",
        type=int,
        default=42,
        help="Must match train.py when using split=val.",
    )
    parser.add_argument(
        "--uvdoc_split_mode",
        type=str,
        default="sample",
        choices=["sample", "geom"],
        help="Must match train.py when using split=val.",
    )
    parser.add_argument(
        "--uvdoc_grid3d_stats",
        type=str,
        default=None,
        help="Must match train.py if used during training.",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="Directory for metrics.txt (and PNGs unless --no_save_images)."
    )
    parser.add_argument(
        "--no_save_images",
        action="store_true",
        help="Only write metrics.txt + mean MSE (for the full val set, avoids huge I/O).",
    )
    parser.add_argument(
        "--max_save_images",
        type=int,
        default=0,
        help="If >0, save at most this many samples' PNG triplets (first indices).",
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    args = parser.parse_args()

    device = torch.device(args.device)

    if args.overfit_n and int(args.overfit_n) > 0:
        ds = data_UVDoc.UVDocDataset(
            data_path=args.data_path_UVDoc,
            appearance_augmentation=[],
            geometric_augmentations=[],
            overfit=True,
            max_samples=int(args.overfit_n),
            deterministic_crop=True,
            grid3d_stats_path=args.uvdoc_grid3d_stats,
        )
    else:
        ds = data_UVDoc.UVDocDataset(
            data_path=args.data_path_UVDoc,
            appearance_augmentation=[],
            geometric_augmentations=[],
            split="val",
            val_ratio=float(args.uvdoc_val_ratio),
            split_seed=int(args.uvdoc_split_seed),
            split_mode=args.uvdoc_split_mode,
            grid3d_stats_path=args.uvdoc_grid3d_stats,
            deterministic_crop=True,
        )

    os.makedirs(args.out_dir, exist_ok=True)

    net = UVDocnet(num_filter=32, kernel_size=5)
    ckpt = torch.load(args.ckpt, map_location=device)
    net.load_state_dict(ckpt["model_state"])
    net.to(device)
    net.eval()

    criterion_mse = nn.MSELoss()

    lines = []
    mses = []
    n_save = 0
    save_cap = args.max_save_images if args.max_save_images and args.max_save_images > 0 else None

    with torch.no_grad():
        for idx in range(len(ds)):
            img_w, img_uw_gt, _, _ = ds[idx]
            x = img_w.unsqueeze(0).to(device)
            gt = img_uw_gt.unsqueeze(0).to(device)

            g2d, _ = net(x)
            pred_uw = utils.bilinear_unwarping(x, g2d, tuple(utils.IMG_SIZE))

            mse = float(criterion_mse(pred_uw, gt))
            mses.append(mse)
            sid = ds.all_samples[idx]
            lines.append(f"{sid} mse={mse:.6f}\n")

            do_save = not args.no_save_images and (save_cap is None or n_save < save_cap)
            if do_save:

                def to_bgr_u8(t):
                    a = (t.squeeze(0).cpu().numpy().transpose(1, 2, 0) * 255.0).clip(0, 255).astype(np.uint8)
                    return cv2.cvtColor(a, cv2.COLOR_RGB2BGR)

                cv2.imwrite(os.path.join(args.out_dir, f"{sid}_in_warped.png"), to_bgr_u8(x))
                cv2.imwrite(os.path.join(args.out_dir, f"{sid}_gt_unwarp.png"), to_bgr_u8(gt))
                cv2.imwrite(os.path.join(args.out_dir, f"{sid}_pred_unwarp.png"), to_bgr_u8(pred_uw))
                n_save += 1

    mean_mse = float(np.mean(mses)) if mses else 0.0
    lines.append(f"mean_mse {mean_mse:.8f} n={len(mses)}\n")

    with open(os.path.join(args.out_dir, "metrics.txt"), "w") as f:
        f.writelines(lines)

    print(f"Wrote metrics for {len(mses)} samples to {args.out_dir}")
    print(f"mean_mse={mean_mse:.8f} (compare to the train log Val MSE for the same checkpoint epoch)")
    if not args.no_save_images and save_cap is not None:
        print(f"Saved PNG triplets for first {n_save} samples (max_save_images={save_cap}).")


if __name__ == "__main__":
    main()
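The docstring of verify_ckpt_val_pipeline.py relies on an averaging identity: when the batch size divides the dataset size evenly, the mean of per-batch MSEs (PyTorch's default `reduction='mean'`) equals the mean of per-image MSEs. A tiny numeric check of that claim:

```python
import numpy as np

# 6 hypothetical per-image MSE values, batch_size=2 divides 6 evenly
per_image = np.array([0.1, 0.3, 0.2, 0.6, 0.4, 0.2])
per_batch = per_image.reshape(3, 2).mean(axis=1)  # MSE of each 2-image batch

# Equal-sized batches make the mean-of-means equal the overall mean
print(np.isclose(per_batch.mean(), per_image.mean()))  # True
```

With a ragged final batch the identity breaks, which is why the docstring qualifies the comparison with "when batch_size divides the val set evenly".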
UVDoc_official/verify_uvdoc_train_infer_preprocess.py
ADDED
@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Assert that UVDoc preprocessing is identical between:

- train.py (overfit / val UVDocDataset kwargs)
- verify_ckpt_val_pipeline.py (same kwargs)

This catches silent drift if one path changes crop/aug/split/grid3d stats.

Optional: compare DataLoader batch tensors to a manual stack (shuffle=False).
"""
from __future__ import annotations

import argparse
import os
import sys

import torch
from torch.utils.data import DataLoader

import data_UVDoc


def _ds_train_overfit(args) -> data_UVDoc.UVDocDataset:
    """Mirror train.py setup_data() when overfit_n > 0 (train branch)."""
    return data_UVDoc.UVDocDataset(
        data_path=args.data_path_UVDoc,
        appearance_augmentation=[],
        geometric_augmentations=[],
        overfit=True,
        max_samples=int(args.overfit_n),
        deterministic_crop=True,
        grid3d_stats_path=args.uvdoc_grid3d_stats,
    )


def _ds_verify_overfit(args) -> data_UVDoc.UVDocDataset:
    """Mirror verify_ckpt_val_pipeline.py when overfit_n > 0."""
    return data_UVDoc.UVDocDataset(
        data_path=args.data_path_UVDoc,
        appearance_augmentation=[],
        geometric_augmentations=[],
        overfit=True,
        max_samples=int(args.overfit_n),
        deterministic_crop=True,
        grid3d_stats_path=args.uvdoc_grid3d_stats,
    )


def _ds_train_val_split(args) -> data_UVDoc.UVDocDataset:
    """Mirror train.py val loader (non-overfit)."""
    return data_UVDoc.UVDocDataset(
        data_path=args.data_path_UVDoc,
        appearance_augmentation=[],
        geometric_augmentations=[],
        split="val",
        val_ratio=float(args.uvdoc_val_ratio),
        split_seed=int(args.uvdoc_split_seed),
        split_mode=args.uvdoc_split_mode,
        grid3d_stats_path=args.uvdoc_grid3d_stats,
    )


def _ds_verify_val_split(args) -> data_UVDoc.UVDocDataset:
    """Mirror verify_ckpt_val_pipeline.py val branch (explicit deterministic_crop)."""
    return data_UVDoc.UVDocDataset(
        data_path=args.data_path_UVDoc,
        appearance_augmentation=[],
        geometric_augmentations=[],
        split="val",
        val_ratio=float(args.uvdoc_val_ratio),
        split_seed=int(args.uvdoc_split_seed),
        split_mode=args.uvdoc_split_mode,
        grid3d_stats_path=args.uvdoc_grid3d_stats,
        deterministic_crop=True,
    )


def _assert_close(name: str, a: torch.Tensor, b: torch.Tensor, rtol: float, atol: float) -> None:
    if a.shape != b.shape:
        raise AssertionError(f"{name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
    if not torch.allclose(a, b, rtol=rtol, atol=atol):
        d = (a.float() - b.float()).abs().max().item()
        raise AssertionError(f"{name}: max abs diff {d} (rtol={rtol}, atol={atol})")


def _compare_sample(ds_a, ds_b, idx: int, rtol: float, atol: float) -> None:
    wa, ua, g2a, g3a = ds_a[idx]
    wb, ub, g2b, g3b = ds_b[idx]
    sid_a = ds_a.all_samples[idx]
    sid_b = ds_b.all_samples[idx]
    if sid_a != sid_b:
        raise AssertionError(f"index {idx}: sample id mismatch {sid_a!r} vs {sid_b!r}")
    _assert_close(f"[{sid_a}] warped", wa, wb, rtol, atol)
    _assert_close(f"[{sid_a}] unwarped_gt", ua, ub, rtol, atol)
    _assert_close(f"[{sid_a}] grid2d", g2a, g2b, rtol, atol)
    _assert_close(f"[{sid_a}] grid3d", g3a, g3b, rtol, atol)


def main() -> int:
    p = argparse.ArgumentParser()
    p.add_argument("--data_path_UVDoc", type=str, required=True)
    p.add_argument("--overfit_n", type=int, default=1, help=">0: compare overfit train vs verify constructors.")
    p.add_argument(
        "--mode",
        type=str,
        default="overfit",
        choices=["overfit", "val_split", "both"],
        help="overfit: first N sorted ids; val_split: hold-out val set kwargs alignment.",
    )
    p.add_argument("--uvdoc_val_ratio", type=float, default=0.05)
    p.add_argument("--uvdoc_split_seed", type=int, default=42)
    p.add_argument("--uvdoc_split_mode", type=str, default="sample", choices=["sample", "geom"])
    p.add_argument("--uvdoc_grid3d_stats", type=str, default=None)
    p.add_argument("--rtol", type=float, default=1e-6)
    p.add_argument("--atol", type=float, default=1e-6)
    p.add_argument(
        "--check_dataloader",
        action="store_true",
        help="Stack dataset tensors and compare to the first batch (shuffle=False, matches index order).",
    )
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--num_workers", type=int, default=0)
    args = p.parse_args()

    if not os.path.isdir(args.data_path_UVDoc):
        print(f"data_path_UVDoc is not a directory: {args.data_path_UVDoc}", file=sys.stderr)
        return 2

    modes = ["overfit", "val_split"] if args.mode == "both" else [args.mode]

    for m in modes:
        if m == "overfit":
            if int(args.overfit_n) <= 0:
                print("--overfit_n must be > 0 for mode=overfit", file=sys.stderr)
                return 2
            a = _ds_train_overfit(args)
            b = _ds_verify_overfit(args)
            tag = f"overfit_n={args.overfit_n}"
        else:
            a = _ds_train_val_split(args)
            b = _ds_verify_val_split(args)
            tag = "val_split"

        if len(a) != len(b):
            raise AssertionError(f"{tag}: len mismatch {len(a)} vs {len(b)}")
        for i in range(len(a)):
            _compare_sample(a, b, i, args.rtol, args.atol)

        if args.check_dataloader:
            loader = DataLoader(a, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=False)
            imgs_w, imgs_uw, g2, g3 = next(iter(loader))
            n = imgs_w.shape[0]
            w_stack = torch.stack([a[i][0] for i in range(n)], dim=0)
            uw_stack = torch.stack([a[i][1] for i in range(n)], dim=0)
            g2_stack = torch.stack([a[i][2] for i in range(n)], dim=0)
            g3_stack = torch.stack([a[i][3] for i in range(n)], dim=0)
            _assert_close(f"{tag} dataloader warped", imgs_w, w_stack, args.rtol, args.atol)
            _assert_close(f"{tag} dataloader unwarped", imgs_uw, uw_stack, args.rtol, args.atol)
            _assert_close(f"{tag} dataloader grid2d", g2, g2_stack, args.rtol, args.atol)
            _assert_close(f"{tag} dataloader grid3d", g3, g3_stack, args.rtol, args.atol)

        print(f"OK: {tag} -- train.py vs verify_ckpt_val_pipeline.py UVDocDataset tensors match ({len(a)} samples).")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
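The `--check_dataloader` comparison in verify_uvdoc_train_infer_preprocess.py works because with `shuffle=False` a DataLoader batch is exactly the stack of the first `batch_size` dataset items, in index order. A self-contained toy check of that logic (the `Toy` dataset is a stand-in, not part of the repo):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class Toy(Dataset):
    # Each item is a 2x2 tensor filled with its own index
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return torch.full((2, 2), float(i))

ds = Toy()
batch = next(iter(DataLoader(ds, batch_size=4, shuffle=False)))
manual = torch.stack([ds[i] for i in range(4)], dim=0)
print(torch.equal(batch, manual))  # True
```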
baseline_resnet_unet/__init__.py
ADDED
@@ -0,0 +1,5 @@
"""ResNet50 + U-Net UV baseline for UVDoc (see the tech note in the repo)."""

from .model import ResNet50UNetUV

__all__ = ["ResNet50UNetUV"]
baseline_resnet_unet/dataset.py
ADDED
@@ -0,0 +1,197 @@
from __future__ import annotations

import json
import os
import random
from os.path import join as pjoin

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

try:
    import albumentations as A
except ImportError:  # pragma: no cover
    A = None

ORIGINAL_GRID_SIZE = (89, 61)


def _appearance_compose(names: list[str]):
    if not names or A is None:
        return None
    transforms = []
    if "visual" in names:
        transforms.append(
            A.OneOf(
                [A.ToSepia(p=15), A.ToGray(p=20), A.Equalize(p=15), A.Sharpen(p=20)],
                p=0.5,
            )
        )
    if "noise" in names:
        transforms.append(
            A.OneOf(
                [
                    A.GaussNoise(var_limit=(10.0, 20.0), p=70),
                    A.ISONoise(intensity=(0.1, 0.25), p=30),
                ],
                p=0.6,
            )
        )
    if "color" in names:
        transforms.append(
            A.OneOf(
                [
                    A.ColorJitter(p=5),
                    A.HueSaturationValue(p=10),
                    A.RandomBrightnessContrast(brightness_limit=[-0.05, 0.25], p=85),
                ],
                p=0.95,
            )
        )
    return A.Compose(transforms=transforms) if transforms else None


def crop_image_tight(img: np.ndarray, grid2d: np.ndarray, deterministic: bool) -> tuple[np.ndarray, int, int, int, int]:
    """grid2d: 2 x Gh x Gw in pixel coords of img."""
    size = img.shape
    minx = int(np.floor(np.amin(grid2d[0, :, :])))
    maxx = int(np.ceil(np.amax(grid2d[0, :, :])))
    miny = int(np.floor(np.amin(grid2d[1, :, :])))
    maxy = int(np.ceil(np.amax(grid2d[1, :, :])))
    s = 20
    s = min(min(s, minx), miny)
    s = min(min(s, size[1] - 1 - maxx), size[0] - 1 - maxy)

    img = img[miny - s : maxy + s, minx - s : maxx + s, :]
    if deterministic:
        cx1 = cy1 = max((s - 5) // 2, 0)
        cx2 = cy2 = max((s - 5) // 2, 0) + 1
    else:
        cx1 = random.randint(0, max(s - 5, 1))
        cx2 = random.randint(0, max(s - 5, 1)) + 1
        cy1 = random.randint(0, max(s - 5, 1))
        cy2 = random.randint(0, max(s - 5, 1)) + 1

    img = img[cy1:-cy2, cx1:-cx2, :]
    top = miny - s + cy1
    bot = size[0] - maxy - s + cy2
    left = minx - s + cx1
    right = size[1] - maxx - s + cx2
    return img, top, bot, left, right


def crop_tight_resize(
    img_rgb: np.ndarray,
    grid2d: np.ndarray,
    out_wh: tuple[int, int],
    deterministic: bool,
) -> tuple[torch.Tensor, np.ndarray]:
    """Return an image tensor CxHxW float [0, 255] and a normalized grid2d (2 x Gh x Gw) in [-1, 1]."""
    size = img_rgb.shape
    img, top, bot, left, right = crop_image_tight(img_rgb, grid2d, deterministic)
    img = cv2.resize(img, out_wh)
    img = img.transpose(2, 0, 1)
    img_t = torch.from_numpy(img).float()

    grid2d = grid2d.copy()
    grid2d[0, :, :] = (grid2d[0, :, :] - left) / (size[1] - left - right)
    grid2d[1, :, :] = (grid2d[1, :, :] - top) / (size[0] - top - bot)
    grid2d = (grid2d * 2.0) - 1.0
    return img_t, grid2d


class UVDocDenseUVDataset(Dataset):
    """
    UVDoc final dataset: img/*.png, with grid2d resolved from each sample's metadata geom
    and loaded from HDF5.
    Produces dense UV GT at (out_h, out_w) via bilinear upsampling (official unwarp convention).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        out_hw: tuple[int, int] = (256, 256),
        appearance_augmentation: list[str] | None = None,
        deterministic_crop: bool | None = None,
    ) -> None:
        super().__init__()
        self.root = root
        self.out_hw = out_hw
        self.out_wh = (out_hw[1], out_hw[0])  # cv2 expects (W, H)
        appearance_augmentation = appearance_augmentation or []
        self.appearance = _appearance_compose(appearance_augmentation)
        split_path = pjoin(root, "split.json")
        all_ids = [x[:-4] for x in os.listdir(pjoin(root, "img")) if x.endswith(".png")]
        all_ids.sort()
        if os.path.isfile(split_path):
            with open(split_path, "r", encoding="utf-8") as f:
                sp = json.load(f)
            key = "train" if split == "train" else "val"
            allowed = set(str(x) for x in sp.get(key, []))
            picked = [i for i in all_ids if i in allowed]
            if not picked and split == "val" and all_ids:
                n = max(1, len(all_ids) // 10)
                picked = all_ids[-n:]
            if not picked:
                picked = list(all_ids)
            self.sample_ids = picked
        else:
            self.sample_ids = all_ids

        if deterministic_crop is None:
            deterministic_crop = split != "train"
        self.deterministic_crop = deterministic_crop

    def __len__(self) -> int:
        return len(self.sample_ids)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        sample_id = self.sample_ids[index]
        with open(pjoin(self.root, "metadata_sample", f"{sample_id}.json"), "r", encoding="utf-8") as f:
            geom_name = json.load(f)["geom_name"]

        img_path = pjoin(self.root, "img", f"{sample_id}.png")
        grid2d_path = pjoin(self.root, "grid2d", f"{geom_name}.mat")

        with h5py.File(grid2d_path, "r") as f:
            grid2d = np.array(f["grid2d"][:].T.transpose(2, 0, 1), dtype=np.float32)

        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise FileNotFoundError(img_path)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        if self.appearance is not None:
            img_rgb = self.appearance(image=img_rgb)["image"]
|
| 169 |
+
|
| 170 |
+
img_t, grid2d_n = crop_tight_resize(
|
| 171 |
+
img_rgb, grid2d, self.out_wh, deterministic=self.deterministic_crop
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
grid_t = torch.from_numpy(grid2d_n).float()
|
| 175 |
+
grid_dense = torch.nn.functional.interpolate(
|
| 176 |
+
grid_t.unsqueeze(0),
|
| 177 |
+
size=self.out_hw,
|
| 178 |
+
mode="bilinear",
|
| 179 |
+
align_corners=True,
|
| 180 |
+
).squeeze(0)
|
| 181 |
+
|
| 182 |
+
img_bchw = (img_t / 255.0).unsqueeze(0)
|
| 183 |
+
unwarped = torch.nn.functional.grid_sample(
|
| 184 |
+
img_bchw,
|
| 185 |
+
grid_dense.permute(1, 2, 0).unsqueeze(0),
|
| 186 |
+
mode="bilinear",
|
| 187 |
+
padding_mode="border",
|
| 188 |
+
align_corners=True,
|
| 189 |
+
).squeeze(0)
|
| 190 |
+
|
| 191 |
+
im = img_t / 255.0
|
| 192 |
+
return {
|
| 193 |
+
"warped": im,
|
| 194 |
+
"uv_gt": grid_dense,
|
| 195 |
+
"unwarped": unwarped,
|
| 196 |
+
"sample_id": sample_id,
|
| 197 |
+
}
|
baseline_resnet_unet/model.py
ADDED
@@ -0,0 +1,89 @@
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import ResNet50_Weights, resnet50


class ConvBNReLU(nn.Module):
    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class UpSkip(nn.Module):
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            ConvBNReLU(out_ch + skip_ch, out_ch),
            ConvBNReLU(out_ch, out_ch),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=True)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class UpNoSkip(nn.Module):
    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(ConvBNReLU(out_ch, out_ch), ConvBNReLU(out_ch, out_ch))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.up(x))


class ResNet50UNetUV(nn.Module):
    """
    ImageNet ResNet50 encoder + U-Net-style decoder.
    Output: B×2×H×W UV in [-1, 1] (Tanh), aligned with UVDoc grid_sample convention.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        w = ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = resnet50(weights=w)

        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.dec43 = UpSkip(2048, 1024, 1024)
        self.dec32 = UpSkip(1024, 512, 512)
        self.dec21 = UpSkip(512, 256, 256)
        self.dec10 = UpNoSkip(256, 128)
        self.dec00 = UpNoSkip(128, 64)
        self.head = nn.Conv2d(64, 2, kernel_size=1)
        self.act = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = self.relu(self.bn1(self.conv1(x)))
        x0 = self.maxpool(x0)
        s1 = self.layer1(x0)
        s2 = self.layer2(s1)
        s3 = self.layer3(s2)
        s4 = self.layer4(s3)
        d3 = self.dec43(s4, s3)
        d2 = self.dec32(d3, s2)
        d1 = self.dec21(d2, s1)
        d0 = self.dec10(d1)
        d = self.dec00(d0)
        return self.act(self.head(d))
baseline_resnet_unet/train.py
ADDED
@@ -0,0 +1,187 @@
from __future__ import annotations

import argparse
from pathlib import Path

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from .dataset import UVDocDenseUVDataset
from .model import ResNet50UNetUV
from .warp import grid_sample_unwarp


IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def collate_batch(samples: list[dict]) -> dict[str, torch.Tensor]:
    return {
        "warped": torch.stack([s["warped"] for s in samples], dim=0),
        "uv_gt": torch.stack([s["uv_gt"] for s in samples], dim=0),
        "unwarped": torch.stack([s["unwarped"] for s in samples], dim=0),
    }


def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None,
    scaler: GradScaler | None,
    l1: nn.Module,
    grad_accum: int,
    train: bool,
    lambda_uv: float,
    lambda_img: float,
) -> tuple[float, float]:
    if train:
        model.train()
    else:
        model.eval()

    tot_uv = 0.0
    tot_img = 0.0
    n = 0
    optimizer_was = optimizer is not None and train

    with torch.set_grad_enabled(train):
        for step, batch in enumerate(loader):
            warped = batch["warped"].to(device, non_blocking=True)
            uv_gt = batch["uv_gt"].to(device, non_blocking=True)
            unwarped_gt = batch["unwarped"].to(device, non_blocking=True)

            mean = IMAGENET_MEAN.to(device)
            std = IMAGENET_STD.to(device)
            warped_in = (warped - mean) / std

            with autocast(enabled=device.type == "cuda"):
                uv_pred = model(warped_in)
                unwarped_pred = grid_sample_unwarp(warped, uv_pred)
                loss_uv = l1(uv_pred, uv_gt)
                loss_img = l1(unwarped_pred, unwarped_gt)
                loss = lambda_uv * loss_uv + lambda_img * loss_img
                loss_scaled = loss / grad_accum

            if optimizer_was:
                if scaler is not None:
                    scaler.scale(loss_scaled).backward()
                else:
                    loss_scaled.backward()
                if (step + 1) % grad_accum == 0:
                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

            tot_uv += float(loss_uv.detach())
            tot_img += float(loss_img.detach())
            n += 1

    return tot_uv / max(1, n), tot_img / max(1, n)


def main() -> None:
    p = argparse.ArgumentParser(description="Train ResNet50+UNet UV baseline on UVDoc")
    p.add_argument("--data_root", type=str, required=True, help="Path to UVDoc final folder (contains img/, grid2d/, ...)")
    p.add_argument("--out_dir", type=str, default="./runs/uvdoc_baseline")
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=1e-2)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--grad_accum", type=int, default=1)
    p.add_argument("--lambda_uv", type=float, default=1.0)
    p.add_argument("--lambda_img", type=float, default=1.0)
    p.add_argument("--h", type=int, default=256)
    p.add_argument("--w", type=int, default=256)
    p.add_argument("--no_pretrained", action="store_true")
    p.add_argument("--appearance_aug", nargs="*", default=["visual", "noise", "color"])
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    train_set = UVDocDenseUVDataset(
        args.data_root,
        split="train",
        out_hw=(args.h, args.w),
        appearance_augmentation=list(args.appearance_aug),
    )
    val_set = UVDocDenseUVDataset(
        args.data_root,
        split="val",
        out_hw=(args.h, args.w),
        appearance_augmentation=[],
        deterministic_crop=True,
    )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        collate_fn=collate_batch,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        collate_fn=collate_batch,
    )

    model = ResNet50UNetUV(pretrained=not args.no_pretrained).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    l1 = nn.L1Loss()
    scaler = GradScaler(enabled=device.type == "cuda")

    for epoch in range(1, args.epochs + 1):
        tr_uv, tr_img = run_epoch(
            model,
            train_loader,
            device,
            optimizer,
            scaler,
            l1,
            args.grad_accum,
            train=True,
            lambda_uv=args.lambda_uv,
            lambda_img=args.lambda_img,
        )
        va_uv, va_img = run_epoch(
            model,
            val_loader,
            device,
            None,
            None,
            l1,
            1,
            train=False,
            lambda_uv=args.lambda_uv,
            lambda_img=args.lambda_img,
        )
        print(
            f"epoch {epoch:03d} | train L1_uv {tr_uv:.5f} L1_img {tr_img:.5f} | "
            f"val L1_uv {va_uv:.5f} L1_img {va_img:.5f}"
        )

        ckpt = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "args": vars(args),
        }
        torch.save(ckpt, out / "last.pt")
        torch.save(ckpt, out / f"epoch_{epoch:03d}.pt")


if __name__ == "__main__":
    main()
baseline_resnet_unet/warp.py
ADDED
@@ -0,0 +1,24 @@
from __future__ import annotations

import torch
import torch.nn.functional as F


def upsample_uv_grid(grid_b2hw: torch.Tensor, out_hw: tuple[int, int]) -> torch.Tensor:
    """Upsample a sparse B×2×Gh×Gw control grid (UVDoc grid2d) to a dense B×2×H×W grid."""
    return F.interpolate(grid_b2hw, size=out_hw, mode="bilinear", align_corners=True)


def grid_sample_unwarp(
    warped_bchw: torch.Tensor,
    grid_b2hw: torch.Tensor,
    align_corners: bool = True,
) -> torch.Tensor:
    """
    warped_bchw: B×3×H×W
    grid_b2hw: B×2×H×W — x,y in [-1, 1] for torch.grid_sample (ch0=x, ch1=y).
    """
    g = grid_b2hw.permute(0, 2, 3, 1).contiguous()
    return F.grid_sample(
        warped_bchw, g, mode="bilinear", padding_mode="border", align_corners=align_corners
    )
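A quick sanity check of the `grid_sample_unwarp` convention: with an identity grid (x, y evenly spaced over [-1, 1] under `align_corners=True`), the unwarp should reproduce the input exactly. The sketch below redefines a minimal copy of the helper so it runs standalone; shapes and values are illustrative.

```python
import torch
import torch.nn.functional as F


def grid_sample_unwarp(warped_bchw, grid_b2hw, align_corners=True):
    # Mirrors baseline_resnet_unet/warp.py: B×2×H×W (ch0=x, ch1=y) -> B×H×W×2.
    g = grid_b2hw.permute(0, 2, 3, 1).contiguous()
    return F.grid_sample(warped_bchw, g, mode="bilinear",
                         padding_mode="border", align_corners=align_corners)


B, C, H, W = 1, 3, 8, 8
img = torch.rand(B, C, H, W)

# Identity grid under align_corners=True: linspace(-1, 1) per axis.
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij"
)
grid = torch.stack([xs, ys], dim=0).unsqueeze(0)  # B×2×H×W, ch0=x, ch1=y

out = grid_sample_unwarp(img, grid)
print(torch.allclose(out, img, atol=1e-6))  # identity warp returns the input
```

Swapping the channel order (y before x) is the most common bug with this convention; the identity check above catches it immediately.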
log_full_uvdoc_gpu0.bak_20260411_122217/nohup.out
ADDED
@@ -0,0 +1,9 @@
/root/miniconda3/envs/o3dedit/lib/python3.10/site-packages/albumentations/check_version.py:147: UserWarning: Error fetching version info <urlopen error [SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1017)>
  data = fetch_version_info()
/mnt/zsn/zsn_workspace/dzx/UvDoc/UVDoc_official/data_utils.py:51: UserWarning: Argument(s) 'var_limit' are not valid for transform GaussNoise
  A.GaussNoise(var_limit=(10.0, 20.0), p=0.70),
/root/miniconda3/envs/o3dedit/lib/python3.10/site-packages/albumentations/core/composition.py:331: UserWarning: Got processor for keypoints, but no transform to process it.
  self._set_keys()

----- Epoch 0 -----
log_full_uvdoc_gpu0.bak_20260411_122217/params8_lr=0.0002_nepochs50_nepochsdecay20_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc.txt
ADDED
@@ -0,0 +1,2 @@
--------------- params8_lr=0.0002_nepochs50_nepochsdecay20_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc ---------------
log_full_uvdoc_gpu0/nohup.out
ADDED
The diff for this file is too large to render. See raw diff.
log_full_uvdoc_gpu0/params8_lr=0.0002_nepochs25_nepochsdecay10_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc.txt
ADDED
@@ -0,0 +1,111 @@
--------------- params8_lr=0.0002_nepochs25_nepochsdecay10_alpha5.0_beta5.0_gamma=1.0_gammastartep10_datauvdoc ---------------

Train LRate: 0.0002 Epoch: 1 MSE: 0.02950
TrainLoss LRate: 0.0002 Epoch: 1 net: 0.31961 L1_g2d: 0.04204 L1_g3d: 0.02189 L1_recon: 0.11414 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 1 MSE: 0.01576
Train LRate: 0.0002 Epoch: 2 MSE: 0.02713
TrainLoss LRate: 0.0002 Epoch: 2 net: 0.24550 L1_g2d: 0.03096 L1_g3d: 0.01814 L1_recon: 0.10817 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 2 MSE: 0.01688
Train LRate: 0.0002 Epoch: 3 MSE: 0.02606
TrainLoss LRate: 0.0002 Epoch: 3 net: 0.21935 L1_g2d: 0.02744 L1_g3d: 0.01643 L1_recon: 0.10558 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 3 MSE: 0.01589
Train LRate: 0.0002 Epoch: 4 MSE: 0.02522
TrainLoss LRate: 0.0002 Epoch: 4 net: 0.19993 L1_g2d: 0.02460 L1_g3d: 0.01539 L1_recon: 0.10323 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 4 MSE: 0.01283
Train LRate: 0.0002 Epoch: 5 MSE: 0.02427
TrainLoss LRate: 0.0002 Epoch: 5 net: 0.18300 L1_g2d: 0.02206 L1_g3d: 0.01454 L1_recon: 0.10088 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 5 MSE: 0.00968
Train LRate: 0.0002 Epoch: 6 MSE: 0.02299
TrainLoss LRate: 0.0002 Epoch: 6 net: 0.16705 L1_g2d: 0.01960 L1_g3d: 0.01381 L1_recon: 0.09772 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 6 MSE: 0.01058
Train LRate: 0.0002 Epoch: 7 MSE: 0.02203
TrainLoss LRate: 0.0002 Epoch: 7 net: 0.15387 L1_g2d: 0.01759 L1_g3d: 0.01318 L1_recon: 0.09528 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 7 MSE: 0.01079
Train LRate: 0.0002 Epoch: 8 MSE: 0.02145
TrainLoss LRate: 0.0002 Epoch: 8 net: 0.14555 L1_g2d: 0.01636 L1_g3d: 0.01275 L1_recon: 0.09364 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 8 MSE: 0.00838
Train LRate: 0.0002 Epoch: 9 MSE: 0.02079
TrainLoss LRate: 0.0002 Epoch: 9 net: 0.13673 L1_g2d: 0.01516 L1_g3d: 0.01219 L1_recon: 0.09175 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 9 MSE: 0.00880
Train LRate: 0.0002 Epoch: 10 MSE: 0.02033
TrainLoss LRate: 0.0002 Epoch: 10 net: 0.13079 L1_g2d: 0.01431 L1_g3d: 0.01185 L1_recon: 0.09071 gamma_w: 0.00000
Val LRate: 0.0002 Epoch: 10 MSE: 0.00919
Train LRate: 0.0002 Epoch: 11 MSE: 0.01994
TrainLoss LRate: 0.0002 Epoch: 11 net: 0.21906 L1_g2d: 0.01408 L1_g3d: 0.01181 L1_recon: 0.08961 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 11 MSE: 0.00806
Train LRate: 0.0002 Epoch: 12 MSE: 0.01935
TrainLoss LRate: 0.0002 Epoch: 12 net: 0.21137 L1_g2d: 0.01310 L1_g3d: 0.01159 L1_recon: 0.08793 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 12 MSE: 0.00672
Train LRate: 0.0002 Epoch: 13 MSE: 0.01916
TrainLoss LRate: 0.0002 Epoch: 13 net: 0.20786 L1_g2d: 0.01276 L1_g3d: 0.01135 L1_recon: 0.08733 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 13 MSE: 0.00796
Train LRate: 0.0002 Epoch: 14 MSE: 0.01898
TrainLoss LRate: 0.0002 Epoch: 14 net: 0.20452 L1_g2d: 0.01237 L1_g3d: 0.01116 L1_recon: 0.08685 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 14 MSE: 0.00735
Train LRate: 0.0002 Epoch: 15 MSE: 0.01891
TrainLoss LRate: 0.0002 Epoch: 15 net: 0.20277 L1_g2d: 0.01234 L1_g3d: 0.01088 L1_recon: 0.08666 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 15 MSE: 0.00655
Train LRate: 0.0002 Epoch: 16 MSE: 0.01871
TrainLoss LRate: 0.0002 Epoch: 16 net: 0.19896 L1_g2d: 0.01190 L1_g3d: 0.01068 L1_recon: 0.08606 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 16 MSE: 0.00819
Train LRate: 0.0002 Epoch: 17 MSE: 0.01849
TrainLoss LRate: 0.0002 Epoch: 17 net: 0.19557 L1_g2d: 0.01158 L1_g3d: 0.01046 L1_recon: 0.08541 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 17 MSE: 0.00602
Train LRate: 0.0002 Epoch: 18 MSE: 0.01825
TrainLoss LRate: 0.0002 Epoch: 18 net: 0.19157 L1_g2d: 0.01111 L1_g3d: 0.01029 L1_recon: 0.08459 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 18 MSE: 0.00702
Train LRate: 0.0002 Epoch: 19 MSE: 0.01806
TrainLoss LRate: 0.0002 Epoch: 19 net: 0.18915 L1_g2d: 0.01089 L1_g3d: 0.01013 L1_recon: 0.08407 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 19 MSE: 0.00634
Train LRate: 0.0002 Epoch: 20 MSE: 0.01788
TrainLoss LRate: 0.0002 Epoch: 20 net: 0.18630 L1_g2d: 0.01060 L1_g3d: 0.00996 L1_recon: 0.08348 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 20 MSE: 0.00589
Train LRate: 0.0002 Epoch: 21 MSE: 0.01795
TrainLoss LRate: 0.0002 Epoch: 21 net: 0.18629 L1_g2d: 0.01067 L1_g3d: 0.00986 L1_recon: 0.08361 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 21 MSE: 0.00598
Train LRate: 0.0002 Epoch: 22 MSE: 0.01767
TrainLoss LRate: 0.0002 Epoch: 22 net: 0.18384 L1_g2d: 0.01047 L1_g3d: 0.00970 L1_recon: 0.08299 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 22 MSE: 0.00614
Train LRate: 0.0002 Epoch: 23 MSE: 0.01750
TrainLoss LRate: 0.0002 Epoch: 23 net: 0.18043 L1_g2d: 0.01009 L1_g3d: 0.00951 L1_recon: 0.08240 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 23 MSE: 0.00605
Train LRate: 0.0002 Epoch: 24 MSE: 0.01747
TrainLoss LRate: 0.0002 Epoch: 24 net: 0.17951 L1_g2d: 0.01000 L1_g3d: 0.00942 L1_recon: 0.08240 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 24 MSE: 0.00744
Train LRate: 0.0002 Epoch: 25 MSE: 0.01731
TrainLoss LRate: 0.0002 Epoch: 25 net: 0.17702 L1_g2d: 0.00977 L1_g3d: 0.00927 L1_recon: 0.08180 gamma_w: 1.00000
Val LRate: 0.0002 Epoch: 25 MSE: 0.00602
Train LRate: 0.00018181818181818183 Epoch: 26 MSE: 0.01726
TrainLoss LRate: 0.00018181818181818183 Epoch: 26 net: 0.17589 L1_g2d: 0.00973 L1_g3d: 0.00913 L1_recon: 0.08159 gamma_w: 1.00000
Val LRate: 0.00018181818181818183 Epoch: 26 MSE: 0.00608
Train LRate: 0.00016363636363636363 Epoch: 27 MSE: 0.01683
TrainLoss LRate: 0.00016363636363636363 Epoch: 27 net: 0.17074 L1_g2d: 0.00917 L1_g3d: 0.00892 L1_recon: 0.08029 gamma_w: 1.00000
Val LRate: 0.00016363636363636363 Epoch: 27 MSE: 0.00621
Train LRate: 0.00014545454545454546 Epoch: 28 MSE: 0.01659
TrainLoss LRate: 0.00014545454545454546 Epoch: 28 net: 0.16742 L1_g2d: 0.00890 L1_g3d: 0.00867 L1_recon: 0.07958 gamma_w: 1.00000
Val LRate: 0.00014545454545454546 Epoch: 28 MSE: 0.00510
Train LRate: 0.00012727272727272728 Epoch: 29 MSE: 0.01617
TrainLoss LRate: 0.00012727272727272728 Epoch: 29 net: 0.16291 L1_g2d: 0.00849 L1_g3d: 0.00842 L1_recon: 0.07833 gamma_w: 1.00000
Val LRate: 0.00012727272727272728 Epoch: 29 MSE: 0.00602
Train LRate: 0.00010909090909090909 Epoch: 30 MSE: 0.01584
TrainLoss LRate: 0.00010909090909090909 Epoch: 30 net: 0.15880 L1_g2d: 0.00810 L1_g3d: 0.00822 L1_recon: 0.07720 gamma_w: 1.00000
Val LRate: 0.00010909090909090909 Epoch: 30 MSE: 0.00571
Train LRate: 9.090909090909092e-05 Epoch: 31 MSE: 0.01553
TrainLoss LRate: 9.090909090909092e-05 Epoch: 31 net: 0.15492 L1_g2d: 0.00774 L1_g3d: 0.00800 L1_recon: 0.07619 gamma_w: 1.00000
Val LRate: 9.090909090909092e-05 Epoch: 31 MSE: 0.00517
Train LRate: 7.272727272727273e-05 Epoch: 32 MSE: 0.01512
TrainLoss LRate: 7.272727272727273e-05 Epoch: 32 net: 0.15098 L1_g2d: 0.00742 L1_g3d: 0.00779 L1_recon: 0.07493 gamma_w: 1.00000
Val LRate: 7.272727272727273e-05 Epoch: 32 MSE: 0.00449
Train LRate: 5.4545454545454546e-05 Epoch: 33 MSE: 0.01474
TrainLoss LRate: 5.4545454545454546e-05 Epoch: 33 net: 0.14681 L1_g2d: 0.00704 L1_g3d: 0.00758 L1_recon: 0.07368 gamma_w: 1.00000
Val LRate: 5.4545454545454546e-05 Epoch: 33 MSE: 0.00450
Train LRate: 3.636363636363636e-05 Epoch: 34 MSE: 0.01435
TrainLoss LRate: 3.636363636363636e-05 Epoch: 34 net: 0.14318 L1_g2d: 0.00676 L1_g3d: 0.00740 L1_recon: 0.07237 gamma_w: 1.00000
Val LRate: 3.636363636363636e-05 Epoch: 34 MSE: 0.00439
Train LRate: 1.818181818181819e-05 Epoch: 35 MSE: 0.01392
TrainLoss LRate: 1.818181818181819e-05 Epoch: 35 net: 0.13949 L1_g2d: 0.00648 L1_g3d: 0.00721 L1_recon: 0.07104 gamma_w: 1.00000
Val LRate: 1.818181818181819e-05 Epoch: 35 MSE: 0.00406
Train LRate: 0.0 Epoch: 36 MSE: 0.01356
TrainLoss LRate: 0.0 Epoch: 36 net: 0.13632 L1_g2d: 0.00626 L1_g3d: 0.00705 L1_recon: 0.06977 gamma_w: 1.00000
Val LRate: 0.0 Epoch: 36 MSE: 0.00392
log_full_uvdoc_gpu0/verify_val_ep12_infer/metrics.txt
ADDED
@@ -0,0 +1,1001 @@
13024 mse=0.013592
11093 mse=0.004805
08684 mse=0.010333
15384 mse=0.019671
02451 mse=0.006121
05058 mse=0.004789
13434 mse=0.003694
04723 mse=0.001129
00502 mse=0.013174
14972 mse=0.003360
17234 mse=0.019906
03267 mse=0.005335
14427 mse=0.005367
03288 mse=0.002329
10878 mse=0.009916
16459 mse=0.000456
19366 mse=0.002625
01213 mse=0.005532
16230 mse=0.010599
04984 mse=0.006088
12651 mse=0.002175
12492 mse=0.000191
19785 mse=0.005227
10426 mse=0.016221
01430 mse=0.005738
13460 mse=0.003349
07653 mse=0.018954
12549 mse=0.001932
18736 mse=0.004445
07825 mse=0.012490
06181 mse=0.003514
10544 mse=0.006107
14920 mse=0.012550
12360 mse=0.000237
17807 mse=0.002510
18085 mse=0.002518
10893 mse=0.004129
11309 mse=0.007160
15870 mse=0.001475
17648 mse=0.001272
14632 mse=0.006853
04869 mse=0.004018
16044 mse=0.010716
19002 mse=0.004499
09783 mse=0.004143
09268 mse=0.007667
00084 mse=0.009441
08754 mse=0.001593
16637 mse=0.004336
11137 mse=0.020461
13596 mse=0.003262
19295 mse=0.022811
02287 mse=0.006038
04473 mse=0.002072
16654 mse=0.003670
06245 mse=0.003957
12569 mse=0.003345
04739 mse=0.001998
18063 mse=0.002227
14967 mse=0.003881
07986 mse=0.001125
15473 mse=0.007439
15810 mse=0.006675
13614 mse=0.003320
08959 mse=0.003891
16674 mse=0.008336
09563 mse=0.006463
16330 mse=0.003470
05045 mse=0.003382
03849 mse=0.011235
15120 mse=0.002886
03717 mse=0.002930
19813 mse=0.004372
03054 mse=0.006552
13369 mse=0.013005
18096 mse=0.008426
00448 mse=0.004613
18049 mse=0.004277
10182 mse=0.008457
05189 mse=0.006347
07841 mse=0.012988
00945 mse=0.005872
03978 mse=0.003278
15883 mse=0.004921
14356 mse=0.007920
09438 mse=0.003210
04293 mse=0.000493
17996 mse=0.005766
08750 mse=0.000794
06670 mse=0.008790
01691 mse=0.003391
14484 mse=0.009287
16334 mse=0.011175
11854 mse=0.011717
16988 mse=0.005557
06146 mse=0.006612
01776 mse=0.005343
13541 mse=0.003095
15138 mse=0.011562
17779 mse=0.004368
00037 mse=0.009647
16862 mse=0.000283
13244 mse=0.008471
03798 mse=0.012445
10593 mse=0.004637
08270 mse=0.006847
10306 mse=0.001354
02885 mse=0.006194
14903 mse=0.003955
01202 mse=0.004690
18025 mse=0.009777
13122 mse=0.007263
09444 mse=0.003730
16235 mse=0.007111
13488 mse=0.002116
10807 mse=0.006151
05269 mse=0.010568
03427 mse=0.014200
10520 mse=0.015780
10921 mse=0.002616
13380 mse=0.007439
11401 mse=0.005065
15401 mse=0.005917
02989 mse=0.020380
16112 mse=0.014466
14283 mse=0.010622
05947 mse=0.010291
09114 mse=0.020833
14490 mse=0.001560
17005 mse=0.005340
15417 mse=0.021213
05740 mse=0.002117
01318 mse=0.005710
18882 mse=0.002263
08831 mse=0.000438
00679 mse=0.003458
18848 mse=0.004376
19801 mse=0.006360
02992 mse=0.011109
00847 mse=0.004666
08204 mse=0.001943
19315 mse=0.014072
14156 mse=0.004043
01306 mse=0.005419
02635 mse=0.009990
14344 mse=0.003637
16180 mse=0.004619
17694 mse=0.006835
16399 mse=0.000178
08920 mse=0.005907
13189 mse=0.009695
09964 mse=0.005229
15358 mse=0.010141
04074 mse=0.003589
14390 mse=0.003879
09457 mse=0.005501
18651 mse=0.002579
10079 mse=0.004674
05436 mse=0.002362
01954 mse=0.005140
19929 mse=0.002276
04660 mse=0.001914
08704 mse=0.001942
04882 mse=0.004099
07366 mse=0.017801
17844 mse=0.004283
03760 mse=0.007480
13930 mse=0.004852
06482 mse=0.008570
08668 mse=0.005134
07648 mse=0.003510
02332 mse=0.006791
14877 mse=0.004230
09264 mse=0.004825
09440 mse=0.001539
07585 mse=0.006412
01069 mse=0.006695
09543 mse=0.005231
15398 mse=0.005820
19199 mse=0.009465
13386 mse=0.004845
00087 mse=0.003284
05129 mse=0.009161
02459 mse=0.007267
18087 mse=0.012211
05695 mse=0.002230
07859 mse=0.004539
05146 mse=0.007429
02660 mse=0.004905
08671 mse=0.005280
06914 mse=0.003935
06808 mse=0.002815
03251 mse=0.003065
01973 mse=0.008081
18632 mse=0.009218
10006 mse=0.004221
14051 mse=0.001517
09743 mse=0.005019
14507 mse=0.006649
14762 mse=0.002845
19426 mse=0.008785
17123 mse=0.009591
07395 mse=0.005674
13637 mse=0.001504
18286 mse=0.027338
03881 mse=0.002515
08132 mse=0.019447
02791 mse=0.001741
09956 mse=0.007228
18443 mse=0.012852
04888 mse=0.005656
04971 mse=0.009306
01756 mse=0.007460
18907 mse=0.003296
02427 mse=0.015776
16113 mse=0.008090
13130 mse=0.004492
02141 mse=0.006973
03738 mse=0.006366
15654 mse=0.006128
08273 mse=0.020160
01683 mse=0.001660
06041 mse=0.005015
11164 mse=0.014999
14433 mse=0.001843
15686 mse=0.004520
15381 mse=0.005756
09317 mse=0.009119
18810 mse=0.007605
00401 mse=0.000677
03361 mse=0.007881
16085 mse=0.008975
04137 mse=0.007379
09572 mse=0.007390
18132 mse=0.004506
01786 mse=0.006291
01491 mse=0.003085
08988 mse=0.013923
04147 mse=0.005262
19404 mse=0.005139
00463 mse=0.006676
09904 mse=0.014308
05879 mse=0.003520
11825 mse=0.006539
07454 mse=0.008526
09707 mse=0.000577
06360 mse=0.005060
03294 mse=0.008094
09273 mse=0.023240
17191 mse=0.003251
09060 mse=0.003096
15866 mse=0.007664
07090 mse=0.007650
06276 mse=0.004397
06066 mse=0.006344
17559 mse=0.002536
03917 mse=0.002943
10339 mse=0.008907
06310 mse=0.004446
03953 mse=0.002108
07554 mse=0.002860
08237 mse=0.007305
10047 mse=0.014969
08877 mse=0.003393
14868 mse=0.006121
11528 mse=0.002175
10839 mse=0.003760
18035 mse=0.002952
13942 mse=0.005696
11264 mse=0.005236
17397 mse=0.012174
18136 mse=0.004356
15507 mse=0.009321
12092 mse=0.015843
19056 mse=0.004750
10033 mse=0.002007
07972 mse=0.003850
13251 mse=0.002938
07261 mse=0.004277
09641 mse=0.004101
03049 mse=0.005762
00565 mse=0.007996
14685 mse=0.002584
16083 mse=0.013000
15600 mse=0.005931
07874 mse=0.012112
18192 mse=0.002150
16009 mse=0.001184
19748 mse=0.015839
07908 mse=0.005978
02805 mse=0.003082
06260 mse=0.005192
16928 mse=0.000296
00332 mse=0.000355
18212 mse=0.004046
09052 mse=0.008125
09163 mse=0.023424
11152 mse=0.004150
08465 mse=0.000505
11626 mse=0.004183
06950 mse=0.003077
15976 mse=0.003224
05958 mse=0.007775
10650 mse=0.003275
16197 mse=0.008717
00940 mse=0.009285
11023 mse=0.003402
19741 mse=0.003408
18672 mse=0.006676
15227 mse=0.003110
07990 mse=0.009484
14771 mse=0.002472
09139 mse=0.006769
02839 mse=0.002735
08044 mse=0.008896
10561 mse=0.004395
08794 mse=0.000091
14272 mse=0.003562
15232 mse=0.005341
15843 mse=0.003702
17634 mse=0.004786
08850 mse=0.000439
17096 mse=0.001433
05771 mse=0.008529
14270 mse=0.002307
05013 mse=0.003193
15064 mse=0.008857
11868 mse=0.001827
14205 mse=0.011431
03458 mse=0.002995
16062 mse=0.006256
13764 mse=0.011470
00414 mse=0.000423
11181 mse=0.002202
16579 mse=0.000423
02259 mse=0.004367
12142 mse=0.014027
14037 mse=0.005371
17456 mse=0.006111
11377 mse=0.006735
01345 mse=0.002136
18544 mse=0.005388
03966 mse=0.004297
09882 mse=0.004373
03339 mse=0.012594
07957 mse=0.006317
02490 mse=0.010495
06498 mse=0.005833
12997 mse=0.007122
12298 mse=0.009941
18867 mse=0.009180
10222 mse=0.009552
00841 mse=0.009046
18962 mse=0.005539
19596 mse=0.001313
10192 mse=0.005371
05366 mse=0.002582
19872 mse=0.006050
05498 mse=0.002254
12087 mse=0.005397
08396 mse=0.000558
15238 mse=0.010978
07142 mse=0.005340
18480 mse=0.002215
03748 mse=0.004140
15517 mse=0.009327
04136 mse=0.012967
07816 mse=0.005544
04993 mse=0.005405
00667 mse=0.001161
06532 mse=0.012035
08102 mse=0.003827
01016 mse=0.011339
01319 mse=0.003026
02248 mse=0.003955
07347 mse=0.005337
11467 mse=0.006970
01706 mse=0.004633
17675 mse=0.005402
00635 mse=0.002192
15411 mse=0.005838
02683 mse=0.007702
08266 mse=0.007328
13813 mse=0.006184
09117 mse=0.017622
19531 mse=0.004532
11059 mse=0.010098
10710 mse=0.004675
14899 mse=0.002513
06935 mse=0.012035
10726 mse=0.003698
12422 mse=0.000294
08524 mse=0.007463
01647 mse=0.009652
05954 mse=0.003376
15128 mse=0.006336
04446 mse=0.000312
14044 mse=0.007312
02750 mse=0.010362
00888 mse=0.017365
18493 mse=0.004126
19395 mse=0.005806
12911 mse=0.003836
00883 mse=0.001296
19927 mse=0.003950
04182 mse=0.002908
15315 mse=0.009212
05408 mse=0.004999
13248 mse=0.005942
11014 mse=0.011532
12742 mse=0.000877
00951 mse=0.008686
19747 mse=0.003179
14281 mse=0.010931
08943 mse=0.006028
07190 mse=0.009106
15998 mse=0.003195
07226 mse=0.004107
17895 mse=0.009976
17636 mse=0.004723
13730 mse=0.002176
12779 mse=0.000939
09784 mse=0.012018
11526 mse=0.002732
00193 mse=0.003580
15629 mse=0.003651
12464 mse=0.000513
18864 mse=0.010520
00640 mse=0.001982
19139 mse=0.010358
12762 mse=0.001432
13767 mse=0.006102
06777 mse=0.004732
02040 mse=0.011819
13624 mse=0.003338
15267 mse=0.015659
17780 mse=0.016504
10842 mse=0.001999
13356 mse=0.009999
13594 mse=0.004251
02104 mse=0.007057
07010 mse=0.004739
05007 mse=0.002228
12551 mse=0.001664
07577 mse=0.010430
04942 mse=0.006112
04534 mse=0.002691
02640 mse=0.003778
17699 mse=0.006536
10476 mse=0.010439
11005 mse=0.006676
08903 mse=0.007070
07683 mse=0.009860
09362 mse=0.002731
02538 mse=0.003612
14413 mse=0.002838
08998 mse=0.007377
07839 mse=0.004246
15879 mse=0.007475
17081 mse=0.005933
09086 mse=0.003096
08101 mse=0.005517
08493 mse=0.011169
14747 mse=0.007154
15554 mse=0.004717
05215 mse=0.006625
14611 mse=0.000665
17952 mse=0.004754
13983 mse=0.004664
16540 mse=0.000416
14502 mse=0.012893
10397 mse=0.002089
18314 mse=0.007516
17404 mse=0.007627
15225 mse=0.001843
04369 mse=0.000670
14957 mse=0.004946
03976 mse=0.003730
08165 mse=0.007345
18265 mse=0.014817
01645 mse=0.006865
15222 mse=0.003725
16972 mse=0.005263
05763 mse=0.010498
07170 mse=0.004263
13931 mse=0.006381
03492 mse=0.004798
00181 mse=0.004791
04835 mse=0.000131
07994 mse=0.005594
13095 mse=0.002430
16199 mse=0.018779
12582 mse=0.001651
06371 mse=0.004494
18863 mse=0.014359
13580 mse=0.001739
14923 mse=0.006933
02386 mse=0.009795
15569 mse=0.003373
08023 mse=0.010031
01514 mse=0.002425
00800 mse=0.000258
04828 mse=0.000532
06525 mse=0.010423
07361 mse=0.010910
10171 mse=0.006608
07696 mse=0.007901
03060 mse=0.002002
10983 mse=0.004688
16890 mse=0.000462
09299 mse=0.004427
02778 mse=0.004951
05560 mse=0.009737
15505 mse=0.011543
16750 mse=0.007996
07002 mse=0.004323
14488 mse=0.005948
14476 mse=0.005617
15236 mse=0.010391
10559 mse=0.008287
19004 mse=0.002443
14086 mse=0.015066
06887 mse=0.005182
09401 mse=0.002061
09957 mse=0.008144
00013 mse=0.005464
17954 mse=0.010544
13306 mse=0.040973
09861 mse=0.015161
18648 mse=0.009899
05702 mse=0.004912
12423 mse=0.000515
13777 mse=0.004517
06286 mse=0.006292
04170 mse=0.012163
18166 mse=0.014863
09688 mse=0.004297
13185 mse=0.007625
10688 mse=0.016156
13382 mse=0.023912
10134 mse=0.004477
16617 mse=0.001585
09840 mse=0.010508
04083 mse=0.006149
10286 mse=0.012526
10819 mse=0.011054
02292 mse=0.006911
14132 mse=0.004875
11934 mse=0.003254
06456 mse=0.019909
17643 mse=0.002607
17036 mse=0.006798
00055 mse=0.005215
14288 mse=0.014392
01454 mse=0.006386
08346 mse=0.000335
06228 mse=0.005780
18893 mse=0.005317
12622 mse=0.014434
03789 mse=0.003452
16758 mse=0.001055
14299 mse=0.021217
10278 mse=0.007585
11327 mse=0.004929
14239 mse=0.003540
19548 mse=0.009048
03552 mse=0.022578
01253 mse=0.005894
08698 mse=0.000867
19024 mse=0.002185
05851 mse=0.008123
08559 mse=0.005700
03779 mse=0.007637
00904 mse=0.008019
10851 mse=0.003364
17570 mse=0.002191
13096 mse=0.005568
16692 mse=0.002952
11506 mse=0.003977
09146 mse=0.008194
02274 mse=0.004023
09129 mse=0.013056
10756 mse=0.008751
13056 mse=0.015209
06328 mse=0.006939
00775 mse=0.000719
07304 mse=0.005530
07457 mse=0.005769
10000 mse=0.002826
11457 mse=0.005553
15083 mse=0.003962
06539 mse=0.003780
07288 mse=0.015032
15423 mse=0.010442
01268 mse=0.010835
12535 mse=0.007065
03542 mse=0.005691
05216 mse=0.008257
08742 mse=0.001468
08131 mse=0.011416
13491 mse=0.009239
10885 mse=0.002555
05877 mse=0.005266
00812 mse=0.003922
13510 mse=0.007523
05801 mse=0.013077
05324 mse=0.009311
07757 mse=0.002630
05064 mse=0.003532
13316 mse=0.008569
18346 mse=0.004035
11589 mse=0.010977
03368 mse=0.010653
08176 mse=0.004915
06883 mse=0.012554
11724 mse=0.010935
01307 mse=0.000943
11948 mse=0.005417
10101 mse=0.003065
01370 mse=0.001716
04176 mse=0.001732
14083 mse=0.004617
04852 mse=0.012606
18105 mse=0.004661
19087 mse=0.002899
12098 mse=0.002071
01180 mse=0.032120
17877 mse=0.010336
04884 mse=0.003506
02465 mse=0.008399
19844 mse=0.005397
00316 mse=0.000524
18379 mse=0.007948
14014 mse=0.005911
18077 mse=0.006781
14478 mse=0.004664
05294 mse=0.021503
08583 mse=0.002269
04286 mse=0.000144
10929 mse=0.005501
00116 mse=0.015530
01444 mse=0.004305
09066 mse=0.004321
19778 mse=0.003098
03024 mse=0.010053
01664 mse=0.007941
08228 mse=0.004464
16008 mse=0.003189
16561 mse=0.003868
08650 mse=0.002290
06671 mse=0.003400
11235 mse=0.010489
06901 mse=0.009372
09232 mse=0.004964
08924 mse=0.021215
05093 mse=0.005350
19874 mse=0.009032
03507 mse=0.003076
03782 mse=0.022239
08666 mse=0.003655
04400 mse=0.000548
03394 mse=0.002589
09810 mse=0.012773
18173 mse=0.003096
00256 mse=0.000407
17332 mse=0.013669
09912 mse=0.005981
17800 mse=0.002989
14358 mse=0.007492
05169 mse=0.015356
09338 mse=0.008593
12108 mse=0.004350
08004 mse=0.005980
02254 mse=0.010440
11436 mse=0.006910
04340 mse=0.000543
19738 mse=0.006448
16576 mse=0.004364
06984 mse=0.003509
17617 mse=0.004963
02400 mse=0.015990
03276 mse=0.005885
18448 mse=0.008356
15017 mse=0.003819
00305 mse=0.000113
02377 mse=0.011819
10360 mse=0.005368
14982 mse=0.011285
09830 mse=0.006814
04288 mse=0.000242
12969 mse=0.005355
08703 mse=0.000688
07821 mse=0.005969
10295 mse=0.003671
06693 mse=0.004699
08544 mse=0.005174
10366 mse=0.004704
17130 mse=0.003539
18521 mse=0.002302
19125 mse=0.003507
13737 mse=0.003372
02686 mse=0.003147
01302 mse=0.005836
19481 mse=0.005929
18970 mse=0.009271
08067 mse=0.002808
18667 mse=0.007331
03928 mse=0.028819
13230 mse=0.004064
07707 mse=0.005022
02226 mse=0.012873
19498 mse=0.004052
02245 mse=0.006626
06089 mse=0.002865
02625 mse=0.008537
16640 mse=0.006780
01863 mse=0.015235
05158 mse=0.005529
17403 mse=0.012846
16477 mse=0.000135
15623 mse=0.016849
19142 mse=0.006173
01643 mse=0.003417
01873 mse=0.006018
10276 mse=0.010539
01997 mse=0.012274
17766 mse=0.002147
18978 mse=0.003328
01916 mse=0.006207
07133 mse=0.010933
09722 mse=0.008000
06222 mse=0.003645
05072 mse=0.002408
15947 mse=0.002638
18211 mse=0.003234
13861 mse=0.018975
09347 mse=0.007893
14909 mse=0.009625
08690 mse=0.001223
12793 mse=0.001838
00070 mse=0.019099
12418 mse=0.000195
05394 mse=0.008171
01921 mse=0.004962
13141 mse=0.003647
07004 mse=0.004832
15773 mse=0.003870
15913 mse=0.005211
13317 mse=0.002295
05449 mse=0.003415
07745 mse=0.001731
03056 mse=0.020637
00483 mse=0.003349
17713 mse=0.004062
01657 mse=0.011674
03208 mse=0.007049
18033 mse=0.011215
14520 mse=0.002908
02470 mse=0.007919
08185 mse=0.003485
15159 mse=0.003266
09127 mse=0.004895
06012 mse=0.002327
13824 mse=0.007023
04593 mse=0.008154
14700 mse=0.015028
17573 mse=0.011038
06232 mse=0.004607
06278 mse=0.005450
08147 mse=0.004964
03580 mse=0.005904
11118 mse=0.004537
13193 mse=0.011521
01986 mse=0.008747
03224 mse=0.001771
01775 mse=0.006740
15303 mse=0.009594
13470 mse=0.003670
13879 mse=0.007771
11609 mse=0.009756
14124 mse=0.003176
03176 mse=0.010371
03090 mse=0.010388
06239 mse=0.015475
13338 mse=0.006884
15498 mse=0.002574
07962 mse=0.007727
15488 mse=0.003228
18881 mse=0.009415
18711 mse=0.002872
04335 mse=0.000312
17669 mse=0.001634
07020 mse=0.003957
15906 mse=0.002830
09125 mse=0.015277
07798 mse=0.017547
16847 mse=0.000129
02321 mse=0.005903
10827 mse=0.004561
01029 mse=0.005611
02208 mse=0.003085
07501 mse=0.005096
01929 mse=0.002922
02326 mse=0.025720
00235 mse=0.003902
07216 mse=0.004973
19282 mse=0.003903
07540 mse=0.010510
18150 mse=0.015341
19277 mse=0.006658
00689 mse=0.001561
11078 mse=0.008348
02098 mse=0.002378
07362 mse=0.013275
08123 mse=0.017092
03965 mse=0.005523
14794 mse=0.003243
16959 mse=0.014355
14355 mse=0.003741
12236 mse=0.004858
13074 mse=0.003443
10214 mse=0.004258
06591 mse=0.003775
17671 mse=0.010929
06940 mse=0.002097
13865 mse=0.005914
17290 mse=0.006832
08685 mse=0.016334
05410 mse=0.008638
18015 mse=0.009019
15574 mse=0.011938
04207 mse=0.006514
04120 mse=0.007746
17455 mse=0.008978
02267 mse=0.011783
15924 mse=0.008287
02806 mse=0.005054
02580 mse=0.003646
18591 mse=0.005672
07892 mse=0.003923
01898 mse=0.004494
07846 mse=0.002609
10076 mse=0.005584
11894 mse=0.010259
03665 mse=0.009702
00638 mse=0.004971
16010 mse=0.004073
10621 mse=0.004637
19626 mse=0.007213
00018 mse=0.006267
17378 mse=0.015393
17674 mse=0.002557
05293 mse=0.014019
12252 mse=0.008291
19885 mse=0.017402
06517 mse=0.008035
19859 mse=0.004822
09779 mse=0.002638
03486 mse=0.003352
16635 mse=0.007544
05854 mse=0.007090
16403 mse=0.001594
08630 mse=0.002533
00106 mse=0.012805
14867 mse=0.007520
05182 mse=0.009494
14246 mse=0.013231
09617 mse=0.001610
03655 mse=0.018741
11146 mse=0.005865
08743 mse=0.001506
17595 mse=0.005043
03753 mse=0.015660
00376 mse=0.000177
18128 mse=0.011650
08238 mse=0.006696
17338 mse=0.012221
15337 mse=0.005766
19526 mse=0.010126
12504 mse=0.001161
12608 mse=0.005191
02081 mse=0.014520
19543 mse=0.006403
13833 mse=0.010069
05242 mse=0.008055
05008 mse=0.005635
03592 mse=0.004524
01543 mse=0.011988
02978 mse=0.005750
16171 mse=0.007243
16696 mse=0.011489
04532 mse=0.004631
07186 mse=0.009895
11861 mse=0.005365
13087 mse=0.004986
19121 mse=0.010848
14038 mse=0.003433
19155 mse=0.008910
08609 mse=0.001800
17661 mse=0.007982
18394 mse=0.006856
08081 mse=0.016075
04575 mse=0.010351
08679 mse=0.003621
04681 mse=0.000964
15035 mse=0.012249
12964 mse=0.005834
16358 mse=0.004575
06967 mse=0.028753
10311 mse=0.004275
18585 mse=0.014650
06913 mse=0.008484
02168 mse=0.015948
08773 mse=0.002038
13145 mse=0.003065
10336 mse=0.017272
01051 mse=0.008510
07505 mse=0.003368
01832 mse=0.009042
10626 mse=0.000839
07196 mse=0.008562
18250 mse=0.001764
08845 mse=0.000329
12433 mse=0.000763
15147 mse=0.011794
05354 mse=0.001267
08021 mse=0.010643
17502 mse=0.002931
05607 mse=0.004029
02339 mse=0.002522
08748 mse=0.000439
06865 mse=0.002523
11641 mse=0.008693
12130 mse=0.005906
05329 mse=0.002781
11954 mse=0.003806
14857 mse=0.004294
09108 mse=0.014271
12455 mse=0.000913
03309 mse=0.005440
07628 mse=0.005613
02614 mse=0.002911
09482 mse=0.008526
07467 mse=0.006411
01501 mse=0.002836
02279 mse=0.006905
06300 mse=0.012404
18918 mse=0.005894
11850 mse=0.005871
09606 mse=0.002263
18089 mse=0.017444
02582 mse=0.003119
12403 mse=0.000606
04090 mse=0.003476
17571 mse=0.002986
15054 mse=0.006459
01423 mse=0.003188
08667 mse=0.008596
19782 mse=0.019427
11270 mse=0.006343
11763 mse=0.005554
03169 mse=0.016593
12449 mse=0.000495
03039 mse=0.002252
03349 mse=0.010690
11029 mse=0.004264
07055 mse=0.003639
05094 mse=0.008496
09105 mse=0.002958
11149 mse=0.010304
13848 mse=0.004117
05231 mse=0.030994
00212 mse=0.007919
09115 mse=0.005759
19309 mse=0.017774
14719 mse=0.015124
07223 mse=0.005964
13746 mse=0.001556
17856 mse=0.005898
06515 mse=0.008227
18390 mse=0.007518
00869 mse=0.004310
19726 mse=0.015906
16559 mse=0.003382
07623 mse=0.011249
07164 mse=0.016578
03070 mse=0.003559
00976 mse=0.002195
01041 mse=0.007530
13825 mse=0.003102
19349 mse=0.017139
02848 mse=0.013111
17870 mse=0.004815
03358 mse=0.013159
04572 mse=0.011368
07314 mse=0.004934
08024 mse=0.005998
09012 mse=0.019656
00819 mse=0.004931
03648 mse=0.009436
mean_mse 0.00672006 n=1000
requirements_baseline.txt
ADDED
@@ -0,0 +1,6 @@
torch>=2.0
torchvision>=0.15
numpy>=1.22
opencv-python>=4.5
h5py>=3.8
albumentations>=1.3.0
requirements_uvdoc_train.txt
ADDED
@@ -0,0 +1,9 @@
# Shared by the official UVDoc training and the local baseline (verified installable in the conda env "uvdoc")
# Usage: conda activate uvdoc && pip install -r requirements_uvdoc_train.txt

torch>=2.0
torchvision>=0.15
numpy>=1.22
h5py>=3.8
opencv-python-headless>=4.7
albumentations>=1.3
run_overfit_official_uvdoc.sh
ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env bash
# Single-sample overfit on UVDoc_final with OFFICIAL default hyperparameters:
#   UVDocnet, lr=2e-4, batch=8, n_epochs=10, n_epochs_decay=10,
#   alpha=beta=5, gamma=1, ep_gamma_start=10.
# Overfit branch uses deterministic crop + no aug (matches verify_ckpt_val_pipeline.py).

set -euo pipefail
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY="${PYTHON:-/root/miniconda3/envs/o3dedit/bin/python}"
LOGDIR="${LOGDIR:-$ROOT/log_overfit_official_uvdoc}"
UV="${UV_DOC_ROOT:-$ROOT/UVDoc_final}"

exec "$PY" "$ROOT/UVDoc_official/train.py" \
  --data_to_use uvdoc \
  --data_path_UVDoc "$UV" \
  --overfit_n 1 \
  --batch_size 8 \
  --n_epochs 10 \
  --n_epochs_decay 10 \
  --lr 0.0002 \
  --alpha_w 5.0 \
  --beta_w 5.0 \
  --gamma_w 1.0 \
  --ep_gamma_start 10 \
  --num_workers "${NUM_WORKERS:-4}" \
  --device "${DEVICE:-cuda:0}" \
  --logdir "$LOGDIR"
run_overfit_train_infer_consistency.sh
ADDED
@@ -0,0 +1,75 @@
#!/usr/bin/env bash
# Single-sample overfit + train/inference data-pipeline consistency check
#
# 1) Quick assertion: train.py and verify_ckpt_val_pipeline.py construct identical UVDocDataset tensors
# 2) Optional: single-sample overfit training (same hyperparameters as run_overfit_official_uvdoc.sh)
# 3) Run verify_ckpt_val_pipeline.py with the same preprocessing; mean_mse should match that epoch's Val MSE in the training log
#
# Usage:
#   PREPROCESS_ONLY=1 ./run_overfit_train_infer_consistency.sh
#   ./run_overfit_train_infer_consistency.sh
#
set -euo pipefail
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY="${PYTHON:-python3}"
UV="${UV_DOC_ROOT:-$ROOT/UVDoc_final}"
OFF="$ROOT/UVDoc_official"
LOGDIR="${LOGDIR:-$ROOT/log_overfit_consistency}"
CKPT_GLOB="${CKPT_GLOB:-}"

cd "$OFF"

echo "== (1) Preprocess alignment: train vs verify_ckpt constructors =="
"$PY" verify_uvdoc_train_infer_preprocess.py \
  --data_path_UVDoc "$UV" \
  --overfit_n 1 \
  --mode overfit \
  --check_dataloader \
  --batch_size 8 \
  --num_workers 0

if [[ "${PREPROCESS_ONLY:-0}" == "1" ]]; then
  echo "PREPROCESS_ONLY=1, skip training and checkpoint verify."
  exit 0
fi

echo "== (2) Single-sample overfit training =="
"$PY" train.py \
  --data_to_use uvdoc \
  --data_path_UVDoc "$UV" \
  --overfit_n 1 \
  --batch_size 8 \
  --n_epochs 10 \
  --n_epochs_decay 10 \
  --lr 0.0002 \
  --alpha_w 5.0 \
  --beta_w 5.0 \
  --gamma_w 1.0 \
  --ep_gamma_start 10 \
  --num_workers "${NUM_WORKERS:-4}" \
  --device "${DEVICE:-cuda:0}" \
  --logdir "$LOGDIR"

# Pick the latest best ckpt (by modification time)
mapfile -t CKPTS < <(ls -t "$LOGDIR"/ep_*_best_model.pkl 2>/dev/null || true)
if [[ ${#CKPTS[@]} -eq 0 ]]; then
  echo "No ep_*_best_model.pkl under $LOGDIR" >&2
  exit 1
fi
CKPT="${CKPT_GLOB:-${CKPTS[0]}}"
echo "Using checkpoint: $CKPT"

OUT="$LOGDIR/verify_infer_same_preprocess"
rm -rf "$OUT"
mkdir -p "$OUT"

echo "== (3) Inference with SAME dataset kwargs as train val/overfit =="
"$PY" verify_ckpt_val_pipeline.py \
  --ckpt "$CKPT" \
  --data_path_UVDoc "$UV" \
  --overfit_n 1 \
  --out_dir "$OUT" \
  --max_save_images 1 \
  --device "${DEVICE:-cuda:0}"

echo "Done. Compare mean_mse in $OUT/metrics.txt to the Val MSE line in train log under $LOGDIR"
run_train_full_uvdoc_gpu0.sh
ADDED
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# Full UVDoc training (no Doc3D), GPU 0, recommended hyperparameters.
# Stop with: kill <pid> (or Ctrl+C if foreground)

set -euo pipefail
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
PY="${PYTHON:-/root/miniconda3/envs/o3dedit/bin/python}"

UV_ROOT="${UV_ROOT:-$ROOT/UVDoc_final}"
LOGDIR="${LOGDIR:-$ROOT/log_full_uvdoc_gpu0}"

# Shorter uvdoc-only schedule (override with N_EPOCHS, N_DECAY env vars).
BS="${BS:-8}"
N_EPOCHS="${N_EPOCHS:-25}"
N_DECAY="${N_DECAY:-10}"
LR="${LR:-0.0002}"
EP_GAMMA="${EP_GAMMA:-10}"
VAL_RATIO="${VAL_RATIO:-0.05}"
SPLIT_SEED="${SPLIT_SEED:-42}"
NUM_WORKERS="${NUM_WORKERS:-8}"

exec "$PY" "$ROOT/UVDoc_official/train.py" \
  --data_to_use uvdoc \
  --data_path_UVDoc "$UV_ROOT" \
  --uvdoc_val_ratio "$VAL_RATIO" \
  --uvdoc_split_seed "$SPLIT_SEED" \
  --batch_size "$BS" \
  --n_epochs "$N_EPOCHS" \
  --n_epochs_decay "$N_DECAY" \
  --lr "$LR" \
  --alpha_w 5.0 \
  --beta_w 5.0 \
  --gamma_w 1.0 \
  --ep_gamma_start "$EP_GAMMA" \
  --appearance_augmentation visual noise color \
  --geometric_augmentationsUVDoc rotate \
  --num_workers "$NUM_WORKERS" \
  --device cuda:0 \
  --logdir "$LOGDIR"
run_train_official_config.sh
ADDED
@@ -0,0 +1,80 @@
#!/usr/bin/env bash
# UVDoc official-style training (matches UVDoc_official/train.py defaults).
#
# Mode A: paper/repo default, Doc3D + UVDoc mixed (--data_to_use both).
#   Set DOC3D_ROOT to your Doc3D dataset; UVDoc to UVDoc_final (or official layout).
#
# Mode B: local UVDoc only, same LR/epoch/gamma schedule as official, no Doc3D.
#   export TRAIN_MODE=uvdoc_only
#
# Defaults (official argparse):
#   batch_size=8, n_epochs=10, n_epochs_decay=10, lr=2e-4,
#   alpha_w=5 beta_w=5 gamma_w=1 ep_gamma_start=10,
#   appearance: visual noise color, UVDoc geom: rotate

set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY="${PYTHON:-python3}"
TRAIN="${ROOT}/UVDoc_official/train.py"

# --- edit paths ---
DOC3D_ROOT="${DOC3D_ROOT:-/path/to/data/doc3D}"
UV_DOC_ROOT="${UV_DOC_ROOT:-/mnt/zsn/zsn_workspace/dzx/UvDoc/UVDoc_final}"
LOGDIR="${LOGDIR:-${ROOT}/log/official_default}"

# TRAIN_MODE: both | uvdoc_only
TRAIN_MODE="${TRAIN_MODE:-both}"

# Official hyperparameters (explicit for clarity)
BS="${BS:-8}"
N_EPOCHS="${N_EPOCHS:-10}"
N_EPOCHS_DECAY="${N_EPOCHS_DECAY:-10}"
LR="${LR:-0.0002}"
ALPHA="${ALPHA:-5.0}"
BETA="${BETA:-5.0}"
GAMMA="${GAMMA:-1.0}"
EP_GAMMA_START="${EP_GAMMA_START:-10}"
UV_VAL_RATIO="${UV_VAL_RATIO:-0.05}"
UV_SPLIT_SEED="${UV_SPLIT_SEED:-42}"
NUM_WORKERS="${NUM_WORKERS:-8}"
DEVICE="${DEVICE:-cuda:0}"

common_args=(
  --batch_size "$BS"
  --n_epochs "$N_EPOCHS"
  --n_epochs_decay "$N_EPOCHS_DECAY"
  --lr "$LR"
  --alpha_w "$ALPHA"
  --beta_w "$BETA"
  --gamma_w "$GAMMA"
  --ep_gamma_start "$EP_GAMMA_START"
  --appearance_augmentation visual noise color
  --geometric_augmentationsUVDoc rotate
  --num_workers "$NUM_WORKERS"
  --device "$DEVICE"
  --logdir "$LOGDIR"
)

if [[ "$TRAIN_MODE" == "both" ]]; then
  if [[ ! -d "$DOC3D_ROOT" ]]; then
    echo "ERROR: DOC3D_ROOT is not a directory: $DOC3D_ROOT"
    echo "Set DOC3D_ROOT to your Doc3D dataset root, or use TRAIN_MODE=uvdoc_only."
    exit 1
  fi
  exec "$PY" "$TRAIN" \
    --data_to_use both \
    --data_path_doc3D "$DOC3D_ROOT" \
    --data_path_UVDoc "$UV_DOC_ROOT" \
    "${common_args[@]}"
elif [[ "$TRAIN_MODE" == "uvdoc_only" ]]; then
  exec "$PY" "$TRAIN" \
    --data_to_use uvdoc \
    --data_path_UVDoc "$UV_DOC_ROOT" \
    --uvdoc_val_ratio "$UV_VAL_RATIO" \
    --uvdoc_split_seed "$UV_SPLIT_SEED" \
    "${common_args[@]}"
else
  echo "TRAIN_MODE must be 'both' or 'uvdoc_only', got: $TRAIN_MODE"
  exit 1
fi
run_train_uvdoc_baseline.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
"""Entrypoint: run from UvDoc directory: python run_train_uvdoc_baseline.py --data_root ./UVDoc_final"""

import sys
from pathlib import Path

if __name__ == "__main__":
    here = Path(__file__).resolve().parent
    sys.path.insert(0, str(here))
    from baseline_resnet_unet.train import main
    main()
unzip_extract.log
ADDED
@@ -0,0 +1 @@
caution: excluded filename not matched: */__MACOSX/*
uvdoc_文档矫正_colab_技术路线(gemini_可执行版).md
ADDED
@@ -0,0 +1,212 @@
# UVDoc Document Dewarping Technical Roadmap (Colab + Gemini Implementation Guide)

---

# 🎯 Goal

Build a trainable, verifiable document-dewarping model on Colab:
- Input: a distorted document image
- Output: a UV map (H×W×2)
- Use grid_sample to produce the rectified image
- Evaluate on the UVDoc benchmark

---

# 🧠 Core idea

At its heart the task is:

> Dense per-pixel mapping (UV map prediction)

The model learns:

f(I) → UV

then:

I_rectified = grid_sample(I, UV)

---
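A minimal PyTorch sketch of this mapping; the `rectify` helper and the identity-grid sanity check are illustrative, not from the UVDoc code:

```python
import torch
import torch.nn.functional as F

def rectify(image, uv):
    """Warp the distorted image with a backward UV map.

    image: [B, 3, H, W]; uv: [B, H, W, 2] in [-1, 1], where uv[b, y, x]
    is the location in `image` sampled for rectified pixel (x, y).
    """
    return F.grid_sample(image, uv, mode="bilinear", align_corners=True)

# Sanity check: the identity UV map reproduces the input image.
B, H, W = 1, 8, 8
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H),
                        torch.linspace(-1, 1, W), indexing="ij")
identity_uv = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # [1, H, W, 2]
img = torch.rand(B, 3, H, W)
out = rectify(img, identity_uv)                           # ~= img
```

Note the channel order: grid_sample expects (x, y) in the last dimension, so a UV map stored as (row, col) must be flipped first.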
# 🏗️ Model architecture (recommended)

## Backbone
- ResNet50 (ImageNet pretrained)

## Decoder
- U-Net style (with skip connections)

## Output
- UV map: [B, H, W, 2]
- Value range: [-1, 1]

---
# ⚙️ Training pipeline

## Step 1: load the pretrained backbone
- Use torchvision ResNet50

## Step 2: build the U-Net decoder
- Upsampling + skip connections

## Step 3: output the UV map
- Tanh on the last layer

## Step 4: warp the image

Using:

```
F.grid_sample(input, UV_pred)
```

---
# 📉 Loss design

## 1. UV loss (core)

L_uv = |UV_pred - UV_gt|

## 2. Image loss

L_img = |I_pred - I_gt|

## 3. Perceptual loss (optional)

Uses VGG features

## Final loss

L = L_uv + 1.0 * L_img + 0.1 * L_perc

---
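A sketch of this loss in PyTorch, assuming a backward UV map as above; `total_loss` and its weight arguments are illustrative names, and the optional perceptual term is taken as a precomputed input rather than implemented here:

```python
import torch
import torch.nn.functional as F

def total_loss(uv_pred, uv_gt, img_distorted, img_gt, l_perc=None,
               w_img=1.0, w_perc=0.1):
    """L = L_uv + w_img * L_img (+ w_perc * L_perc), all L1 terms."""
    l_uv = F.l1_loss(uv_pred, uv_gt)
    # Rectify with the predicted map, then compare to the flat ground truth.
    img_pred = F.grid_sample(img_distorted, uv_pred, align_corners=True)
    l_img = F.l1_loss(img_pred, img_gt)
    loss = l_uv + w_img * l_img
    if l_perc is not None:   # e.g. an L1 in VGG feature space, if used
        loss = loss + w_perc * l_perc
    return loss

B, H, W = 2, 16, 16
uv = torch.rand(B, H, W, 2) * 2 - 1
loss = total_loss(uv, uv.clone(),
                  torch.rand(B, 3, H, W), torch.rand(B, 3, H, W))
```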
# 📦 Data preparation

## UVDoc data

Must include:
- input image
- GT UV map
- GT rectified image (optional)

## Preprocessing

- resize to 256×256
- normalize

---
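A minimal preprocessing sketch; the `preprocess` name is illustrative, and the normalization here is a plain 0-1 rescale (swap in mean/std normalization if the backbone's pretraining expects it):

```python
import numpy as np
import torch
import torch.nn.functional as F

def preprocess(img_u8):
    """HWC uint8 image -> [1, 3, 256, 256] float tensor in [0, 1]."""
    t = torch.from_numpy(img_u8).permute(2, 0, 1).float() / 255.0
    t = F.interpolate(t.unsqueeze(0), size=(256, 256),
                      mode="bilinear", align_corners=False)
    return t

x = preprocess(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
```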
# 🚀 Colab setup

## Recommended settings

- GPU: T4 / A100
- batch size: 4~8
- epochs: 50+
- optimizer: AdamW
- lr: 1e-4

---
# ⚡ Performance optimizations

## Must do

- mixed precision (torch.cuda.amp)
- gradient accumulation

## Recommended

- random crops
- data augmentation (brightness, contrast)

---
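The two "must do" items combine into one training-step pattern; the linear model and batch shapes below are toy stand-ins for the UV network and dataloader:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)                    # stand-in for the UV network
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
accum_steps = 4                                   # effective batch = 4x batch_size
w0 = model.weight.detach().clone()

opt.zero_grad(set_to_none=True)
for step in range(8):
    x, y = torch.rand(4, 10), torch.rand(4, 2)    # stand-in batch
    with torch.autocast("cuda" if use_cuda else "cpu", enabled=use_cuda):
        loss = F.l1_loss(model(x), y) / accum_steps   # scale for accumulation
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:             # step every accum_steps batches
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
```

On CPU (or with `enabled=False`) the scaler and autocast degrade to no-ops, so the same loop runs everywhere.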
# 📊 Evaluation metrics

- L1 UV error
- PSNR
- SSIM

---
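PSNR is simple enough to sketch inline (for images scaled to [0, 1]; the `psnr` helper is an illustrative name):

```python
import torch

def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)

a = torch.zeros(1, 3, 8, 8)
b = torch.full_like(a, 0.1)   # uniform error of 0.1 -> MSE = 0.01
val = psnr(a, b)              # 10 * log10(1 / 0.01) = 20 dB
```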
# 🧪 Training strategy (strongly recommended)

## Phase 1 (optional)

Pretrain on synthetic data:
- generate distortions with TPS
- UV GT is generated automatically

## Phase 2

Finetune on UVDoc

---

# ⚠️ Common mistakes

- ❌ Wrong UV coordinate direction (forward/backward map confusion)
- ❌ Not using grid_sample
- ❌ Predicting the image directly
- ❌ OOM from too-large resolution

---
# 🧾 Gemini prompt (ready to use)

Copy the following to Gemini:

---

You need to implement a document-dewarping model on Google Colab, with these requirements:

1. Use PyTorch
2. Backbone: ResNet50 (ImageNet pretrained)
3. Decoder: U-Net
4. Output: a UV map (H×W×2, range [-1, 1])
5. Use torch.nn.functional.grid_sample to produce the rectified image

Training:
- Loss = L1(UV) + L1(image)
- optimizer: AdamW
- use mixed precision

Data:
- input images
- UV GT

Deliverables:
- complete training code
- including model, dataset, and train loop
- runnable directly on Colab

---

# ✅ Expected outcome

You should end up with:

- a trainable model
- UV map outputs
- visualizations of the rectified results
- an evaluation on UVDoc

---

# 📌 Notes

If results are weak, upgrade:
- backbone → Swin-T
- add a perceptual loss
- use multi-scale

---

# 🎯 One-sentence summary

Predict a UV map with ResNet + U-Net, then rebuild the image with grid_sample; this is the most reliable UVDoc baseline.